# NOTE: the original page header ("Spaces: Running on Zero") was a Hugging Face
# Spaces banner captured during extraction; it is not part of the code.
| from typing import List, Optional, Tuple | |
| import torch | |
| import torch.nn.functional as F | |
class ImagePairsManipulationBase:
    """Base class for image-pair manipulations (resize, crop, pad).

    Subclasses implement ``__call__``, ``output_shape`` and ``check_input``;
    the ``*_pairs`` helpers below are derived from the single-image methods.
    Regions are 4-element tensors laid out as [top, bottom, left, right].
    """

    def __init__(self):
        pass

    def __call__(
        self,
        img0: torch.Tensor,
        img1: torch.Tensor,
        img0_region_source: torch.Tensor,
        img1_region_source: torch.Tensor,
        img0_region_representation: torch.Tensor,
        img1_region_representation: torch.Tensor,
    ):
        """
        Apply resizing, cropping, and padding to image pairs while recording correspondence information.
        Args:
        - img0: Tensor of shape (B, H, W, C), dtype uint8 representing the first set of images.
        - img1: Tensor of shape (B, H, W, C), dtype uint8 representing the second set of images.
        - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence.
        - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence.
        - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region.
        - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region.
        Returns:
        - img0: Tensor of image0 after manipulation.
        - img1: Tensor of image1 after manipulation.
        - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence.
        - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence.
        - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region.
        - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region.
        """
        raise NotImplementedError

    def output_shape(self, H: int, W: int) -> Tuple[int, int]:
        """
        Compute the output shape of a single image after the manipulation.
        Args:
        - H: Height of the image.
        - W: Width of the image.
        Returns:
        Tuple of (H_out, W_out) representing the output shape of the image if the manipulation is applied.
        """
        # NOTE(fix): the original docstring claimed a 4-tuple (H1, W1, H2, W2)
        # return here, contradicting the Tuple[int, int] signature; the pairwise
        # variant is output_shape_pairs below.
        raise NotImplementedError

    def output_shape_pairs(self, H1: int, W1: int, H2: int, W2: int) -> Tuple[int, int, int, int]:
        """
        Compute the output shapes of both images after the manipulation.
        Returns:
        Tuple of (H1_out, W1_out, H2_out, W2_out).
        """
        output1 = self.output_shape(H1, W1)
        output2 = self.output_shape(H2, W2)
        return output1[0], output1[1], output2[0], output2[1]

    def check_input(self, H: int, W: int) -> bool:
        """
        Check whether the input shape is valid for the current manipulation.
        Args:
        - H: Height of the image.
        - W: Width of the image.
        Returns:
        Whether the manipulation can run on the given input shape.
        """
        raise NotImplementedError

    def check_input_pairs(self, H1: int, W1: int, H2: int, W2: int) -> bool:
        """Whether the manipulation can run on both input shapes."""
        return self.check_input(H1, W1) and self.check_input(H2, W2)
class ResizeHorizontalAxisManipulation(ImagePairsManipulationBase):
    """Resize images so their width equals ``horizontal_axis``, preserving aspect ratio."""

    def __init__(self, horizontal_axis: int):
        super().__init__()
        # Target width (pixels) shared by both images after the resize.
        self.horizontal_axis = horizontal_axis

    def output_shape(self, H: int, W: int) -> Tuple[int, int]:
        """
        Compute the output shape of the image after the resize operation.
        """
        resize_ratio = self.horizontal_axis / W
        return (int(H * resize_ratio), self.horizontal_axis)

    def check_input(self, H: int, W: int) -> bool:
        # An aspect-preserving resize can run on any input resolution.
        return True

    def __call__(
        self,
        img0: torch.Tensor,
        img1: torch.Tensor,
        img0_region_source: torch.Tensor,
        img1_region_source: torch.Tensor,
        img0_region_representation: torch.Tensor,
        img1_region_representation: torch.Tensor,
    ):
        """
        Resize both images to the target width while recording correspondence information.
        Args:
        - img0: Tensor of shape (B, H, W, C), dtype float32 or uint8 representing the first set of images.
        - img1: Tensor of shape (B, H, W, C), dtype float32 or uint8 representing the second set of images.
        - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence.
        - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence.
        - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region.
        - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region.
        Returns:
        The same 6-tuple with the images resized, the source regions unchanged,
        and the representation regions scaled by the per-axis resize factors.
        """
        _, h0, w0, _ = img0.shape
        _, h1, w1, _ = img1.shape
        target_h0, target_w0, target_h1, target_w1 = self.output_shape_pairs(h0, w0, h1, w1)
        assert img0.dtype == img1.dtype, "Image types must match"
        is_uint8 = img0.dtype == torch.uint8
        # F.interpolate expects channels-first (B, C, H, W); permute in and out.
        img0_resized = F.interpolate(
            img0.permute(0, 3, 1, 2).float(), size=(target_h0, target_w0), mode="bilinear", align_corners=False
        ).permute(0, 2, 3, 1)
        img1_resized = F.interpolate(
            img1.permute(0, 3, 1, 2).float(), size=(target_h1, target_w1), mode="bilinear", align_corners=False
        ).permute(0, 2, 3, 1)
        if is_uint8:
            img0_resized = img0_resized.to(torch.uint8)
            img1_resized = img1_resized.to(torch.uint8)
        # Regions are [top, bottom, left, right]; scale each axis by its resize factor.
        # Multipliers are created on the region's device so GPU-resident regions work too
        # (the original built them on the CPU unconditionally).
        h_mult0 = target_h0 / h0
        w_mult0 = target_w0 / w0
        multiplier0 = torch.tensor([h_mult0, h_mult0, w_mult0, w_mult0], device=img0_region_representation.device)
        h_mult1 = target_h1 / h1
        w_mult1 = target_w1 / w1
        multiplier1 = torch.tensor([h_mult1, h_mult1, w_mult1, w_mult1], device=img1_region_representation.device)
        # source region is unchanged; representation region is scaled
        img0_region_representation = multiplier0 * img0_region_representation
        img1_region_representation = multiplier1 * img1_region_representation
        return (
            img0_resized,
            img1_resized,
            img0_region_source,
            img1_region_source,
            img0_region_representation,
            img1_region_representation,
        )
class ResizeVerticalAxisManipulation(ImagePairsManipulationBase):
    """Resize images so their height equals ``vertical_axis``, preserving aspect ratio."""

    def __init__(self, vertical_axis: int):
        super().__init__()
        # Target height (pixels) shared by both images after the resize.
        self.vertical_axis = vertical_axis

    def output_shape(self, H: int, W: int) -> Tuple[int, int]:
        """
        Compute the output shape of the image after the resize operation.
        """
        resize_ratio = self.vertical_axis / H
        return (self.vertical_axis, int(W * resize_ratio))

    def check_input(self, H: int, W: int) -> bool:
        # An aspect-preserving resize can run on any input resolution.
        return True

    def __call__(
        self,
        img0: torch.Tensor,
        img1: torch.Tensor,
        img0_region_source: torch.Tensor,
        img1_region_source: torch.Tensor,
        img0_region_representation: torch.Tensor,
        img1_region_representation: torch.Tensor,
    ):
        """
        Resize both images to the target height while recording correspondence information.
        Args:
        - img0: Tensor of shape (B, H, W, C), dtype float32 or uint8 representing the first set of images.
        - img1: Tensor of shape (B, H, W, C), dtype float32 or uint8 representing the second set of images.
        - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence.
        - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence.
        - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region.
        - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region.
        Returns:
        The same 6-tuple with the images resized, the source regions unchanged,
        and the representation regions scaled by the per-axis resize factors.
        """
        _, h0, w0, _ = img0.shape
        _, h1, w1, _ = img1.shape
        target_h0, target_w0, target_h1, target_w1 = self.output_shape_pairs(h0, w0, h1, w1)
        assert img0.dtype == img1.dtype, "Image types must match"
        is_uint8 = img0.dtype == torch.uint8
        # F.interpolate expects channels-first (B, C, H, W); permute in and out.
        img0_resized = F.interpolate(
            img0.permute(0, 3, 1, 2).float(), size=(target_h0, target_w0), mode="bilinear", align_corners=False
        ).permute(0, 2, 3, 1)
        img1_resized = F.interpolate(
            img1.permute(0, 3, 1, 2).float(), size=(target_h1, target_w1), mode="bilinear", align_corners=False
        ).permute(0, 2, 3, 1)
        if is_uint8:
            img0_resized = img0_resized.to(torch.uint8)
            img1_resized = img1_resized.to(torch.uint8)
        # Regions are [top, bottom, left, right]; scale each axis by its resize factor.
        # Multipliers live on the region's device so GPU-resident regions work too
        # (the original built them on the CPU unconditionally).
        h_mult0 = target_h0 / h0
        w_mult0 = target_w0 / w0
        multiplier0 = torch.tensor([h_mult0, h_mult0, w_mult0, w_mult0], device=img0_region_representation.device)
        h_mult1 = target_h1 / h1
        w_mult1 = target_w1 / w1
        multiplier1 = torch.tensor([h_mult1, h_mult1, w_mult1, w_mult1], device=img1_region_representation.device)
        # source region is unchanged; representation region is scaled
        img0_region_representation = multiplier0 * img0_region_representation
        img1_region_representation = multiplier1 * img1_region_representation
        return (
            img0_resized,
            img1_resized,
            img0_region_source,
            img1_region_source,
            img0_region_representation,
            img1_region_representation,
        )
class ResizeToFixedManipulation(ImagePairsManipulationBase):
    """Resize both images to a fixed (height, width), not preserving aspect ratio."""

    def __init__(self, target_shape: Tuple[int, int]):
        super().__init__()
        # (height, width) every input is resized to.
        self.target_shape = target_shape

    def output_shape(self, H: int, W: int) -> Tuple[int, int]:
        """
        Compute the output shape of the image after the resize operation.
        """
        return self.target_shape

    def check_input(self, H: int, W: int) -> bool:
        # Any resolution can be resized to the fixed target.
        return True

    def __call__(
        self,
        img0: torch.Tensor,
        img1: torch.Tensor,
        img0_region_source: torch.Tensor,
        img1_region_source: torch.Tensor,
        img0_region_representation: torch.Tensor,
        img1_region_representation: torch.Tensor,
    ):
        """
        Resize both images to the fixed target shape while recording correspondence information.
        Args:
        - img0: Tensor of shape (B, H, W, C), dtype float32 or uint8 representing the first set of images.
        - img1: Tensor of shape (B, H, W, C), dtype float32 or uint8 representing the second set of images.
        - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence.
        - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence.
        - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region.
        - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region.
        Returns:
        The same 6-tuple with the images resized, the source regions unchanged,
        and the representation regions scaled and rounded down to integer indices.
        """
        _, h0, w0, _ = img0.shape
        _, h1, w1, _ = img1.shape
        target_h0, target_w0, target_h1, target_w1 = self.output_shape_pairs(h0, w0, h1, w1)
        assert img0.dtype == img1.dtype, "Image types must match"
        is_uint8 = img0.dtype == torch.uint8
        # F.interpolate expects channels-first (B, C, H, W); antialias reduces
        # aliasing artifacts when downscaling.
        img0_resized = F.interpolate(
            img0.permute(0, 3, 1, 2).float(),
            size=(target_h0, target_w0),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        ).permute(0, 2, 3, 1)
        img1_resized = F.interpolate(
            img1.permute(0, 3, 1, 2).float(),
            size=(target_h1, target_w1),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        ).permute(0, 2, 3, 1)
        if is_uint8:
            img0_resized = img0_resized.to(torch.uint8)
            img1_resized = img1_resized.to(torch.uint8)
        # Regions are [top, bottom, left, right]; scale each axis by its resize factor.
        # Multipliers live on the region's device so GPU-resident regions work too
        # (the original built them on the CPU unconditionally).
        h_mult0 = target_h0 / h0
        w_mult0 = target_w0 / w0
        multiplier0 = torch.tensor([h_mult0, h_mult0, w_mult0, w_mult0], device=img0_region_representation.device)
        h_mult1 = target_h1 / h1
        w_mult1 = target_w1 / w1
        multiplier1 = torch.tensor([h_mult1, h_mult1, w_mult1, w_mult1], device=img1_region_representation.device)
        # source region is unchanged; representation region is scaled, then truncated
        # to int64 so it can be used directly as pixel indices.
        img0_region_representation = (multiplier0 * img0_region_representation).to(torch.int64)
        img1_region_representation = (multiplier1 * img1_region_representation).to(torch.int64)
        return (
            img0_resized,
            img1_resized,
            img0_region_source,
            img1_region_source,
            img0_region_representation,
            img1_region_representation,
        )
def scale_axis(
    source_low: float,
    source_high: float,
    reference_low: float,
    reference_high: float,
    reference_low_new: float,
    reference_high_new: float,
):
    """Remap a [low, high] interval after its reference interval is narrowed.

    The new reference endpoints are expressed as fractions of the old reference
    interval, and those fractions are applied to the source interval, returning
    the corresponding (new_low, new_high) in source coordinates.
    """
    reference_span = reference_high - reference_low
    source_span = source_high - source_low
    fraction_low = (reference_low_new - reference_low) / reference_span
    fraction_high = (reference_high_new - reference_low) / reference_span
    return (
        source_low + fraction_low * source_span,
        source_low + fraction_high * source_span,
    )
class CenterCropManipulation(ImagePairsManipulationBase):
    """Center-crop both images to ``target_size`` and remap the correspondence regions."""

    def __init__(self, target_size: Tuple[int, int]):
        super().__init__()
        # (height, width) of the crop window.
        self.target_size = target_size

    def output_shape(self, H: int, W: int) -> Tuple[int, int]:
        """
        Compute the output shape of the image after the crop operation.
        """
        return self.target_size

    def check_input(self, H: int, W: int) -> bool:
        # The crop window must fit inside the image.
        return H >= self.target_size[0] and W >= self.target_size[1]

    def __call__(
        self,
        img0: torch.Tensor,
        img1: torch.Tensor,
        img0_region_source: torch.Tensor,
        img1_region_source: torch.Tensor,
        img0_region_representation: torch.Tensor,
        img1_region_representation: torch.Tensor,
    ):
        """
        Center-crop both images while recording correspondence information.
        Args:
        - img0: Tensor of shape (B, H, W, C), dtype uint8 representing the first set of images.
        - img1: Tensor of shape (B, H, W, C), dtype uint8 representing the second set of images.
        - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence.
        - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence.
        - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region.
        - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region.
        Returns:
        The same 6-tuple with the images cropped, the representation regions
        clipped to the crop window and shifted into crop coordinates, and the
        source regions narrowed proportionally (returned as new tensors; the
        caller's inputs are not modified).
        """
        B0, H0, W0, C0 = img0.shape
        B1, H1, W1, C1 = img1.shape
        target_h, target_w = self.target_size
        assert H0 >= target_h and W0 >= target_w, "Image shapes must be larger than the target size."
        assert H1 >= target_h and W1 >= target_w, "Image shapes must be larger than the target size."
        # Split the excess evenly; any odd pixel goes to the bottom/right side.
        crop_top_0 = (H0 - target_h) // 2
        crop_bottom_0 = H0 - target_h - crop_top_0
        crop_left_0 = (W0 - target_w) // 2
        crop_right_0 = W0 - target_w - crop_left_0
        crop_top_1 = (H1 - target_h) // 2
        crop_bottom_1 = H1 - target_h - crop_top_1
        crop_left_1 = (W1 - target_w) // 2
        crop_right_1 = W1 - target_w - crop_left_1
        # apply the crops
        img0_cropped = img0[:, crop_top_0 : H0 - crop_bottom_0, crop_left_0 : W0 - crop_right_0, :]
        img1_cropped = img1[:, crop_top_1 : H1 - crop_bottom_1, crop_left_1 : W1 - crop_right_1, :]
        # Intersect each representation region ([top, bottom, left, right]) with the crop
        # window: the crop may remove part of the valid region.
        remaining_region_0 = torch.tensor(
            [
                max(img0_region_representation[0], crop_top_0),
                min(img0_region_representation[1], H0 - crop_bottom_0),
                max(img0_region_representation[2], crop_left_0),
                min(img0_region_representation[3], W0 - crop_right_0),
            ]
        )
        remaining_region_1 = torch.tensor(
            [
                max(img1_region_representation[0], crop_top_1),
                min(img1_region_representation[1], H1 - crop_bottom_1),
                max(img1_region_representation[2], crop_left_1),
                min(img1_region_representation[3], W1 - crop_right_1),
            ]
        )
        # shift the remaining region as the cropped region is removed
        img0_region_representation_new = remaining_region_0 - torch.tensor(
            [crop_top_0, crop_top_0, crop_left_0, crop_left_0]
        )
        img1_region_representation_new = remaining_region_1 - torch.tensor(
            [crop_top_1, crop_top_1, crop_left_1, crop_left_1]
        )
        img0_region_representation_new = img0_region_representation_new.to(torch.int64)
        img1_region_representation_new = img1_region_representation_new.to(torch.int64)
        # Part of the valid region may have been cropped away, so narrow the source
        # regions proportionally. Work on clones so the caller's tensors are not
        # mutated in place (the original wrote back into its input arguments).
        img0_region_source = img0_region_source.clone()
        img1_region_source = img1_region_source.clone()
        img0_region_source[0], img0_region_source[1] = scale_axis(
            img0_region_source[0],
            img0_region_source[1],
            img0_region_representation[0],
            img0_region_representation[1],
            remaining_region_0[0],
            remaining_region_0[1],
        )
        img0_region_source[2], img0_region_source[3] = scale_axis(
            img0_region_source[2],
            img0_region_source[3],
            img0_region_representation[2],
            img0_region_representation[3],
            remaining_region_0[2],
            remaining_region_0[3],
        )
        img1_region_source[0], img1_region_source[1] = scale_axis(
            img1_region_source[0],
            img1_region_source[1],
            img1_region_representation[0],
            img1_region_representation[1],
            remaining_region_1[0],
            remaining_region_1[1],
        )
        img1_region_source[2], img1_region_source[3] = scale_axis(
            img1_region_source[2],
            img1_region_source[3],
            img1_region_representation[2],
            img1_region_representation[3],
            remaining_region_1[2],
            remaining_region_1[3],
        )
        return (
            img0_cropped,
            img1_cropped,
            img0_region_source,
            img1_region_source,
            img0_region_representation_new,
            img1_region_representation_new,
        )
class ImagePairsManipulationComposite(ImagePairsManipulationBase):
    """Apply a sequence of manipulations in order, threading images and regions through."""

    def __init__(self, *manipulations: ImagePairsManipulationBase):
        # NOTE(fix): each vararg is a single manipulation, so the annotation is
        # ImagePairsManipulationBase, not List[ImagePairsManipulationBase].
        super().__init__()
        self.manipulations = manipulations

    def output_shape(self, H: int, W: int) -> Tuple[int, int]:
        """
        Compute the output shape of the image after all manipulations are applied in order.
        """
        output_shape = (H, W)
        for manipulation in self.manipulations:
            output_shape = manipulation.output_shape(*output_shape)
        return output_shape

    def output_shape_pairs(self, H1: int, W1: int, H2: int, W2: int) -> Tuple[int, int, int, int]:
        """
        Compute the output shapes of both images after all manipulations are applied in order.
        """
        output_shape = (H1, W1, H2, W2)
        for manipulation in self.manipulations:
            output_shape = manipulation.output_shape_pairs(*output_shape)
        return output_shape

    def check_input(self, H: int, W: int) -> bool:
        # Each step must accept the shape produced by the previous step.
        current_shape = (H, W)
        for manipulation in self.manipulations:
            if not manipulation.check_input(*current_shape):
                return False
            current_shape = manipulation.output_shape(*current_shape)
        return True

    def check_input_pairs(self, H1: int, W1: int, H2: int, W2: int) -> bool:
        # Pairwise variant of check_input: thread both shapes through the chain.
        current_shape = (H1, W1, H2, W2)
        for manipulation in self.manipulations:
            if not manipulation.check_input_pairs(*current_shape):
                return False
            current_shape = manipulation.output_shape_pairs(*current_shape)
        return True

    def __call__(
        self,
        img0: torch.Tensor,
        img1: torch.Tensor,
        img0_region_source: torch.Tensor,
        img1_region_source: torch.Tensor,
        img0_region_representation: torch.Tensor,
        img1_region_representation: torch.Tensor,
    ):
        """
        Apply all manipulations in order while recording correspondence information.
        Args:
        - img0: Tensor of shape (B, H, W, C), dtype uint8 representing the first set of images.
        - img1: Tensor of shape (B, H, W, C), dtype uint8 representing the second set of images.
        - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence.
        - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence.
        - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region.
        - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region.
        Returns:
        The same 6-tuple after every manipulation has been applied in sequence.
        """
        for manipulation in self.manipulations:
            (
                img0,
                img1,
                img0_region_source,
                img1_region_source,
                img0_region_representation,
                img1_region_representation,
            ) = manipulation(
                img0,
                img1,
                img0_region_source,
                img1_region_source,
                img0_region_representation,
                img1_region_representation,
            )
        return (
            img0,
            img1,
            img0_region_source,
            img1_region_source,
            img0_region_representation,
            img1_region_representation,
        )
class AutomaticShapeSelection(ImagePairsManipulationBase):
    """Pick, among several candidate manipulations, the one best suited to the input shapes."""

    def __init__(self, *manipulations: ImagePairsManipulationBase, strategy="closest_aspect"):
        super().__init__()
        self.manipulations = manipulations
        if strategy == "closest_aspect":
            self.strategy = self._closest_aspect_strategy
        else:
            raise ValueError("Unknown strategy")

    def output_shape(self, H: int, W: int) -> Tuple[int, int]:
        """
        Compute the output shape of the image after the selected manipulation.
        """
        output_shape, _augmentor = self.strategy(H, W)
        if output_shape is None:
            raise ValueError("No valid shape found for the given resolution.")
        return output_shape

    def output_shape_pairs(self, H1: int, W1: int, H2: int, W2: int) -> Tuple[int, int, int, int]:
        """
        Compute the output shapes of both images after the selected manipulation.
        """
        output_shape, _augmentor = self.strategy(H1, W1, H2, W2)
        if output_shape is None:
            raise ValueError("No valid shape found for the given resolution.")
        return output_shape

    def check_input(self, H: int, W: int) -> bool:
        # Valid iff at least one candidate manipulation can run on this shape.
        output_shape, _augmentor = self.strategy(H, W)
        return output_shape is not None

    def check_input_pairs(self, H1: int, W1: int, H2: int, W2: int) -> bool:
        # Valid iff at least one candidate manipulation can run on both shapes.
        output_shape, _augmentor = self.strategy(H1, W1, H2, W2)
        return output_shape is not None

    def _closest_aspect_strategy(self, H: int, W: int, *shape_img1):
        """Select the runnable candidate whose output aspect ratio is closest to the input's.

        ``shape_img1`` is empty for the single-image variant or ``(H2, W2)`` for pairs.
        Returns ``(output_shape, manipulation)``, or ``(None, None)`` if no candidate can run.
        """
        # For all candidate manipulations, first check whether they can run at the
        # given resolution.
        # NOTE(fix): ``shape_img1`` is a varargs tuple and can never be None, so the
        # original ``if shape_img1 is None`` always took the pairwise branch and the
        # single-image calls (check_input / output_shape) crashed; test emptiness instead.
        if not shape_img1:
            runnable_sizes = [
                (manipulator.output_shape(H, W), manipulator)
                for manipulator in self.manipulations
                if manipulator.check_input(H, W)
            ]
        else:
            runnable_sizes = [
                (manipulator.output_shape_pairs(H, W, *shape_img1), manipulator)
                for manipulator in self.manipulations
                if manipulator.check_input_pairs(H, W, *shape_img1)
            ]
        if len(runnable_sizes) == 0:
            return None, None
        # Among the runnable candidates, select the one whose aspect ratio(s) are
        # closest to the input's.
        if not shape_img1:
            closest_size, closest_augmentor = min(runnable_sizes, key=lambda x: abs(x[0][0] / x[0][1] - H / W))
        else:
            closest_size, closest_augmentor = min(
                runnable_sizes,
                key=lambda x: abs(x[0][0] / x[0][1] - H / W) + abs(x[0][2] / x[0][3] - shape_img1[0] / shape_img1[1]),
            )
        return closest_size, closest_augmentor

    def __call__(
        self,
        img0: torch.Tensor,
        img1: torch.Tensor,
        img0_region_source: Optional[torch.Tensor] = None,
        img1_region_source: Optional[torch.Tensor] = None,
        img0_region_representation: Optional[torch.Tensor] = None,
        img1_region_representation: Optional[torch.Tensor] = None,
    ):
        """
        Apply the best-matching manipulation while recording correspondence information.
        Args:
        - img0: Tensor of shape (B, H, W, C), dtype uint8 representing the first set of images.
        - img1: Tensor of shape (B, H, W, C), dtype uint8 representing the second set of images.
        - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence.
        - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence.
        - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region.
        - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region.
        When no regions are given they default to the full image extent.
        Returns:
        The 6-tuple produced by the selected manipulation.
        Raises:
        ValueError: if no candidate manipulation can run on the input shapes.
        """
        H0, W0 = img0.shape[1], img0.shape[2]
        H1, W1 = img1.shape[1], img1.shape[2]
        output_shape, augmentor = self.strategy(H0, W0, H1, W1)
        if output_shape is None:
            raise ValueError("No valid shape found for the given resolution.")
        if img0_region_source is None:
            # Regions must be given all-or-none; default to the full image extent
            # [top, bottom, left, right].
            assert img1_region_source is None
            assert img0_region_representation is None
            assert img1_region_representation is None
            img0_region_source = torch.tensor([0, H0, 0, W0])
            img1_region_source = torch.tensor([0, H1, 0, W1])
            img0_region_representation = torch.tensor([0, H0, 0, W0])
            img1_region_representation = torch.tensor([0, H1, 0, W1])
        return augmentor(
            img0, img1, img0_region_source, img1_region_source, img0_region_representation, img1_region_representation
        )
# unmap the predicted flow to match the input. Flow is unique semantically as its value changes
# depending on the source and target region.
def unmap_predicted_flow(
    flow: torch.Tensor,
    img0_region_representation: torch.Tensor,
    img1_region_representation: torch.Tensor,
    img0_region_source: torch.Tensor,
    img1_region_source: torch.Tensor,
    img0_source_shape: Tuple[int, int],
    img1_source_shape: Tuple[int, int],
):
    """
    Unmap the predicted flow to the original image space.

    Args:
    - flow: Tensor of shape (B, 2, H, W) representing the predicted flow between the two regions,
      channels ordered (dx, dy).
    - img0_region_representation: Tensor of size 4 (top, bottom, left, right) for the region of img0
      in current representation space corresponding to the source region.
    - img1_region_representation: Tensor of size 4 for the region of img1 in representation space.
    - img0_region_source: Tensor of size 4 (top, bottom, left, right) for the region of img0 that is
      the source of the correspondence.
    - img1_region_source: Tensor of size 4 for the source region of img1.
    - img0_source_shape: (H, W) of the original img0; defines the output canvas.
    - img1_source_shape: (H, W) of the original img1; kept for interface symmetry (unused — only
      img0's geometry defines the output grid).

    Returns:
    - flow_output: Tensor of shape (B, 2, img0_source_shape[0], img0_source_shape[1]) with the flow
      embedded in the original img0 space (zeros outside the valid region).
    - flow_valid: Bool tensor of shape (B, img0_source_shape[0], img0_source_shape[1]) marking where
      the unmapped flow is defined.
    """
    B = flow.shape[0]
    # Step 1: crop the flow to img0's valid region in the model's output (representation) space.
    # NOTE: no offset between the source regions is applied here; the region offsets are handled
    # explicitly in Step 4 below.
    flow_valid2valid = flow[
        ...,
        img0_region_representation[0] : img0_region_representation[1],
        img0_region_representation[2] : img0_region_representation[3],
    ]
    # Step 2: Represent the flow as pairs of source and target coordinates.
    # Pixel centers sit at +0.5; channels are stacked in (x, y) order to match the flow layout.
    source_coordinates = (
        torch.stack(
            torch.meshgrid(
                torch.arange(0, flow_valid2valid.shape[3]) + 0.5,
                torch.arange(0, flow_valid2valid.shape[2]) + 0.5,
                indexing="xy",
            ),
            dim=-1,
        )
        .permute(2, 0, 1)
        .unsqueeze(0)
        .to(flow.device)
    )
    # Step 3: Scale the flow to the source space. Notice that here we can actually assume
    # valid representation space have the same shape.
    # So it looks like both source and target coordinates are scaled according to the source representation.
    # now we scale the valid2valid flow from representation space to source space
    source_valid_shape = torch.tensor(
        [img0_region_source[1] - img0_region_source[0], img0_region_source[3] - img0_region_source[2]]
    )
    target_valid_shape = torch.tensor(
        [img1_region_source[1] - img1_region_source[0], img1_region_source[3] - img1_region_source[2]]
    )
    # upscale source and target coordinates to the source space
    source_coordinates_valid = F.interpolate(
        source_coordinates.float(), size=source_valid_shape.tolist(), mode="bilinear", align_corners=False
    )
    # This is equivalently we define "target_coordinates = source_coordinates + flow_valid2valid" and apply the scaling.
    # since we have a flow component, we can only do nearest interpolation, but this will cause ~0.5 pixel error
    # because we are interpolating the source_coordinates also linearly.
    target_coordinates_valid = (
        F.interpolate(flow_valid2valid.float(), size=source_valid_shape.tolist(), mode="nearest")
        + source_coordinates_valid
    )
    # apply different scaling to the flow: representation for source maps to source_valid_shape in source space
    source_coordinates_valid *= (
        torch.tensor(
            [
                source_valid_shape[1] / (img0_region_representation[3] - img0_region_representation[2]),
                source_valid_shape[0] / (img0_region_representation[1] - img0_region_representation[0]),
            ]
        )
        .view(1, 2, 1, 1)
        .to(flow.device)
    )
    # target coordinates are scaled to the target source space, which may be different from the source space.
    # NOTE(review): the divisor uses img0's representation region; this relies on the assumption
    # above that both valid representation regions have the same shape — confirm if that changes.
    target_coordinates_valid *= (
        torch.tensor(
            [
                target_valid_shape[1] / (img0_region_representation[3] - img0_region_representation[2]),
                target_valid_shape[0] / (img0_region_representation[1] - img0_region_representation[0]),
            ]
        )
        .view(1, 2, 1, 1)
        .to(flow.device)
    )
    # Step 4: Offset the coordinates from valid source space to the original source space ((x, y) order).
    source_coordinates_valid += (
        torch.tensor([img0_region_source[2], img0_region_source[0]]).view(1, 2, 1, 1).to(flow.device)
    )
    target_coordinates_valid += (
        torch.tensor([img1_region_source[2], img1_region_source[0]]).view(1, 2, 1, 1).to(flow.device)
    )
    # now we can compute the flow in the source space
    flow_source = target_coordinates_valid - source_coordinates_valid
    # Step 5: Embed the flow in its original space, zero outside the valid region.
    flow_output = torch.zeros(
        (B, 2, img0_source_shape[0], img0_source_shape[1]), dtype=flow.dtype, device=flow.device
    )
    flow_output[
        ..., img0_region_source[0] : img0_region_source[1], img0_region_source[2] : img0_region_source[3]
    ] = flow_source
    flow_valid = torch.zeros((B, img0_source_shape[0], img0_source_shape[1]), dtype=torch.bool, device=flow.device)
    flow_valid[..., img0_region_source[0] : img0_region_source[1], img0_region_source[2] : img0_region_source[3]] = True
    return flow_output, flow_valid
# unmap predicted source - target point pairs.
def unmap_predicted_pairs(
    source_points: torch.Tensor,
    target_points: torch.Tensor,
    img0_region_representation: torch.Tensor,
    img1_region_representation: torch.Tensor,
    img0_region_source: torch.Tensor,
    img1_region_source: torch.Tensor,
    img0_source_shape: Tuple[int, int],
    img1_source_shape: Tuple[int, int],
):
    """
    Unmap predicted point pairs from representation space to the original image space.

    Note: the input point tensors are modified in place and also returned.

    Args:
    - source_points: Tensor of shape (B, N, 2) with predicted source points as (x, y).
    - target_points: Tensor of shape (B, N, 2) with predicted target points as (x, y).
    - img0_region_representation: Tensor of size 4 (top, bottom, left, right) for the region of img0
      in current representation space corresponding to the source region.
    - img1_region_representation: Tensor of size 4 for the region of img1 in representation space.
    - img0_region_source: Tensor of size 4 for the region of img0 that is the source of the correspondence.
    - img1_region_source: Tensor of size 4 for the source region of img1.
    - img0_source_shape: (H, W) of the original img0 (kept for interface symmetry; unused).
    - img1_source_shape: (H, W) of the original img1 (kept for interface symmetry; unused).

    Returns:
    - source_points: points mapped into the original img0 space.
    - target_points: points mapped into the original img1 space.
    """
    # Scale each axis from the representation region to the corresponding source region.
    # x uses the (left, right) bounds region[2:4]; y uses the (top, bottom) bounds region[0:2].
    source_points[:, :, 0], _ = scale_axis(
        img0_region_source[2],
        img0_region_source[3],
        img0_region_representation[2],
        img0_region_representation[3],
        source_points[:, :, 0],
        0.0,
    )
    source_points[:, :, 1], _ = scale_axis(
        img0_region_source[0],
        img0_region_source[1],
        img0_region_representation[0],
        img0_region_representation[1],
        source_points[:, :, 1],
        0.0,
    )
    target_points[:, :, 0], _ = scale_axis(
        img1_region_source[2],
        img1_region_source[3],
        img1_region_representation[2],
        img1_region_representation[3],
        target_points[:, :, 0],
        0.0,
    )
    target_points[:, :, 1], _ = scale_axis(
        img1_region_source[0],
        img1_region_source[1],
        img1_region_representation[0],
        img1_region_representation[1],
        target_points[:, :, 1],
        0.0,
    )
    return source_points, target_points
# unmap normal channels like confidence, depth, etc.
# much simpler than the flow case
def unmap_predicted_channels(
    channel: torch.Tensor,
    img0_region_representation: torch.Tensor,
    img0_region_source: torch.Tensor,
    img0_source_shape: Tuple[int, int],
):
    """
    Unmap predicted per-pixel channels to the original image space.

    Crops the valid representation-space region, resizes it (nearest neighbor) to the
    source-space region, and embeds it into a zero canvas of the original image shape.

    Args:
    - channel: Tensor of shape (B, C, H, W) representing the predicted values in img0 representation space.
    - img0_region_representation: Tensor of size 4 (top, bottom, left, right) for the region of img0
      in current representation space corresponding to the source region.
    - img0_region_source: Tensor of size 4 for the region of img0 that is the source of the correspondence.
    - img0_source_shape: Tuple of size 2 with the (H, W) shape of the original image.
    Returns:
    - channel_output: Tensor of shape (B, C, *img0_source_shape) with values embedded in the original space.
    - channel_valid: Bool tensor of shape (B, *img0_source_shape) marking the valid region.
    """
    batch, num_channels = channel.shape[0], channel.shape[1]
    rep_top, rep_bottom, rep_left, rep_right = img0_region_representation
    src_top, src_bottom, src_left, src_right = img0_region_source
    # Crop the valid region out of representation space.
    roi = channel[..., rep_top:rep_bottom, rep_left:rep_right]
    # Resize the ROI to the source-space region size; nearest keeps the channel values unchanged.
    roi_in_source = F.interpolate(
        roi,
        size=[int(src_bottom - src_top), int(src_right - src_left)],
        mode="nearest",
    )
    # Embed into a zero canvas of the original image shape.
    channel_output = torch.zeros(
        (batch, num_channels, img0_source_shape[0], img0_source_shape[1]),
        dtype=channel.dtype,
        device=channel.device,
    )
    channel_output[..., src_top:src_bottom, src_left:src_right] = roi_in_source
    channel_valid = torch.zeros(
        (batch, img0_source_shape[0], img0_source_shape[1]), dtype=torch.bool, device=channel.device
    )
    channel_valid[..., src_top:src_bottom, src_left:src_right] = True
    return channel_output, channel_valid
if __name__ == "__main__":
    # Visual smoke test: a single bright pixel moves from (25%, 25%) of img0 to (50%, 75%)
    # of img1; the flow should survive the resize/crop round trip and the unmapping back
    # to the original source space.
    import matplotlib.pyplot as plt

    import flow_vis

    # Example pair whose heights bracket the (288, 512) target aspect: one below, one above.
    img0 = torch.zeros((1, 145, 256, 3), dtype=torch.uint8)
    img1 = torch.zeros((1, 135, 256, 3), dtype=torch.uint8)
    source_pt = img0.shape[1] * 0.25, img0.shape[2] * 0.25
    target_pt = img1.shape[1] * 0.5, img1.shape[2] * 0.75
    img0[0, int(source_pt[0]), int(source_pt[1]), :] = 255
    img1[0, int(target_pt[0]), int(target_pt[1]), :] = 255
    # Ground-truth flow carries (dx, dy) at the single source pixel.
    flow_gt = torch.zeros((1, 2, 145, 256))
    flow_gt[0, :, int(source_pt[0]), int(source_pt[1])] = torch.tensor(
        [target_pt[1] - source_pt[1], target_pt[0] - source_pt[0]]
    )
    H0, W0 = img0.shape[1], img0.shape[2]
    H1, W1 = img1.shape[1], img1.shape[2]
    manipulation = AutomaticShapeSelection(
        ImagePairsManipulationComposite(ResizeHorizontalAxisManipulation(512), CenterCropManipulation((288, 512))),
        ImagePairsManipulationComposite(ResizeHorizontalAxisManipulation(512), CenterCropManipulation((200, 512))),
    )
    (
        img0_resized,
        img1_resized,
        img0_region_source,
        img1_region_source,
        img0_region_representation,
        img1_region_representation,
    ) = manipulation(img0, img1)
    fig, axs = plt.subplots(2, 3)
    axs[0, 0].imshow(img0[0].numpy())
    axs[0, 1].imshow(img0_resized[0].numpy())
    axs[1, 0].imshow(img1[0].numpy())
    axs[1, 1].imshow(img1_resized[0].numpy())
    print(img0_region_source)
    print(img1_region_source)
    print(img0_region_representation)
    print(img1_region_representation)
    # Hand-crafted "prediction" in representation space to exercise the unmapping.
    flow_pred = torch.zeros((1, 2, 288, 512))
    flow_pred[0, :, 28, 128] = torch.tensor([256, 72])
    # unmap the flow back to the original img0 space
    flow_unmapped, flow_validity = unmap_predicted_flow(
        flow_pred,
        img0_region_representation,
        img1_region_representation,
        img0_region_source,
        img1_region_source,
        (H0, W0),
        (H1, W1),
    )
    flow_unmapped = flow_unmapped[0]
    flow_validity = flow_validity[0]
    flow_rgb = flow_vis.flow_to_color(flow_unmapped.permute(1, 2, 0).numpy())
    axs[0, 2].imshow(flow_validity)
    plt.figure()
    plt.imshow(flow_rgb)
    plt.show()