"""Crop shim: rescale and center-crop images (and optional depth maps) while
keeping the camera intrinsics consistent. The intrinsics are assumed to be
normalized by image size (fx = fx_pixels / width, fy = fy_pixels / height)."""

import random

import cv2
import numpy as np
import torch
import torchvision.transforms.functional as F
from einops import rearrange
from jaxtyping import Float
from PIL import Image
from torch import Tensor

from ..types import AnyExample, AnyViews

def rescale(
    image: Float[Tensor, "3 h_in w_in"],
    shape: tuple[int, int],
) -> Float[Tensor, "3 h_out w_out"]:
    """Resize an image to `shape` (h, w) using PIL's Lanczos filter."""
    h, w = shape
    # Round-trip through uint8 so PIL can do the resampling.
    image_new = (image * 255).clip(min=0, max=255).type(torch.uint8)
    image_new = rearrange(image_new, "c h w -> h w c").detach().cpu().numpy()
    image_new = Image.fromarray(image_new)
    image_new = image_new.resize((w, h), Image.LANCZOS)
    image_new = np.array(image_new) / 255
    image_new = torch.tensor(image_new, dtype=image.dtype, device=image.device)
    return rearrange(image_new, "h w c -> c h w")

def rescale_depth(
    depth: Float[Tensor, "*channel h_in w_in"],
    shape: tuple[int, int],
) -> Float[Tensor, "*channel h_out w_out"]:
    """Resize a depth map to `shape` (h, w) with nearest-neighbor sampling,
    which avoids blending depths across object boundaries."""
    h, w = shape
    depth_new = depth.detach().cpu().numpy()
    # cv2.resize expects an (h, w) array and a (w, h) size argument, so drop
    # a leading channel dimension if present and restore it afterwards.
    had_channel = depth_new.ndim == 3
    if had_channel:
        depth_new = depth_new.squeeze(0)
    depth_new = cv2.resize(depth_new, (w, h), interpolation=cv2.INTER_NEAREST)
    depth_new = torch.from_numpy(depth_new).to(depth.device)
    return depth_new.unsqueeze(0) if had_channel else depth_new

def center_crop(
    images: Float[Tensor, "*#batch c h w"],
    intrinsics: Float[Tensor, "*#batch 3 3"],
    shape: tuple[int, int],
    depths: Float[Tensor, "*#batch 1 h w"] | None = None,
) -> tuple[
    Float[Tensor, "*#batch c h_out w_out"],  # updated images
    Float[Tensor, "*#batch 3 3"],  # updated intrinsics
    Float[Tensor, "*#batch 1 h_out w_out"] | None,  # updated depths
]:
    """Center-crop images (and optional depths) to `shape`, updating the
    normalized intrinsics to match."""
    *_, h_in, w_in = images.shape
    h_out, w_out = shape

    # Note that odd input dimensions induce half-pixel misalignments.
    row = (h_in - h_out) // 2
    col = (w_in - w_out) // 2

    # Center-crop the image.
    images = images[..., :, row : row + h_out, col : col + w_out]
    if depths is not None:
        depths = depths[..., row : row + h_out, col : col + w_out]

    # Adjust the intrinsics to account for the cropping. With intrinsics
    # normalized by image size, a centered crop keeps the principal point at
    # 0.5 and only rescales the focal lengths.
    intrinsics = intrinsics.clone()
    intrinsics[..., 0, 0] *= w_in / w_out  # fx
    intrinsics[..., 1, 1] *= h_in / h_out  # fy

    # Always return a triple so callers can unpack uniformly; depths may be
    # None.
    return images, intrinsics, depths

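# Worked example for the intrinsics update above (a sketch, assuming
# normalized intrinsics, i.e. fx = fx_pixels / width): center-cropping a
# 256-pixel-wide image to 128 pixels leaves fx_pixels unchanged while the
# width halves, so the normalized fx doubles -- matching the applied factor
# w_in / w_out = 256 / 128 = 2.
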
def rescale_and_crop(
    images: Float[Tensor, "*#batch c h w"],
    intrinsics: Float[Tensor, "*#batch 3 3"],
    shape: tuple[int, int],
    intr_aug: bool = False,
    scale_range: tuple[float, float] = (0.77, 1.0),
    depths: Float[Tensor, "*#batch 1 h w"] | None = None,
) -> tuple[
    Float[Tensor, "*#batch c h_out w_out"],  # updated images
    Float[Tensor, "*#batch 3 3"],  # updated intrinsics
    Float[Tensor, "*#batch 1 h_out w_out"] | None,  # updated depths
]:
    if isinstance(images, list):
        # Heterogeneously sized views arrive as a list and are processed one
        # at a time. Note that intr_aug is not applied in this branch.
        images_new = []
        intrinsics_new = []
        depths_new = []
        for i in range(len(images)):
            image = images[i]
            intrinsic = intrinsics[i]
            *_, h_in, w_in = image.shape
            h_out, w_out = shape
            scale_factor = max(h_out / h_in, w_out / w_in)
            h_scaled = round(h_in * scale_factor)
            w_scaled = round(w_in * scale_factor)
            image = F.resize(image, (h_scaled, w_scaled))
            image = F.center_crop(image, (h_out, w_out))
            images_new.append(image)
            # Resizing leaves normalized intrinsics unchanged; only the
            # center crop from (h_scaled, w_scaled) to (h_out, w_out)
            # rescales the focal lengths (matching center_crop above).
            intrinsic_new = intrinsic.clone()
            intrinsic_new[..., 0, 0] *= w_scaled / w_out  # fx
            intrinsic_new[..., 1, 1] *= h_scaled / h_out  # fy
            intrinsics_new.append(intrinsic_new)
            if depths is not None:
                # Apply the same resize + crop so depth stays aligned with
                # its image.
                depth = rescale_depth(depths[i], (h_scaled, w_scaled))
                depth = F.center_crop(depth, (h_out, w_out))
                depths_new.append(depth)
        depths_out = torch.stack(depths_new) if depths is not None else None
        return torch.stack(images_new), torch.stack(intrinsics_new), depths_out
    else:
        # We only support intr_aug for clean datasets.
        *_, h_in, w_in = images.shape
        h_out, w_out = shape
        # assert h_out <= h_in and w_out <= w_in  # to avoid the case that the image is too small, like co3d
        if intr_aug:
            # Rescale to a slightly smaller intermediate crop, then resize
            # back up to (h_out, w_out), which jitters the effective focal
            # length.
            scale = random.uniform(*scale_range)
            h_scale = round(h_out * scale)
            w_scale = round(w_out * scale)
        else:
            h_scale = h_out
            w_scale = w_out
        scale_factor = max(h_scale / h_in, w_scale / w_in)
        h_scaled = round(h_in * scale_factor)
        w_scaled = round(w_in * scale_factor)
        assert h_scaled == h_scale or w_scaled == w_scale

        # Reshape the images to the correct size. Assume we don't have to worry about
        # changing the intrinsics based on how the images are rounded.
        *batch, c, h, w = images.shape
        images = images.reshape(-1, c, h, w)
        images = torch.stack([rescale(image, (h_scaled, w_scaled)) for image in images])
        images = images.reshape(*batch, c, h_scaled, w_scaled)

        if depths is not None:
            if isinstance(depths, list):
                depths = torch.stack(
                    [rescale_depth(depth, (h_scaled, w_scaled)) for depth in depths]
                )
            else:
                depths = depths.reshape(-1, h, w)
                depths = torch.stack(
                    [rescale_depth(depth, (h_scaled, w_scaled)) for depth in depths]
                )
                # Restore the channel dimension dropped by the reshape above.
                depths = depths.reshape(*batch, 1, h_scaled, w_scaled)

        images, intrinsics, depths = center_crop(images, intrinsics, (h_scale, w_scale), depths)
        if intr_aug:
            images = F.resize(images, size=(h_out, w_out), interpolation=F.InterpolationMode.BILINEAR)
            if depths is not None:
                depths = F.resize(depths, size=(h_out, w_out), interpolation=F.InterpolationMode.NEAREST)
        return images, intrinsics, depths

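# A minimal usage sketch for rescale_and_crop (hypothetical shapes/values):
#
#     images = torch.rand(8, 3, 540, 960)
#     intrinsics = torch.eye(3).expand(8, 3, 3).clone()
#     images, intrinsics, _ = rescale_and_crop(images, intrinsics, (256, 256))
#
# Here scale_factor = max(256 / 540, 256 / 960) = 256 / 540, so the inputs
# are Lanczos-rescaled to 256 x 455 and then center-cropped to 256 x 256;
# the returned fx is scaled by 455 / 256 while fy is unchanged.
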
def apply_crop_shim_to_views(
    views: AnyViews, shape: tuple[int, int], intr_aug: bool = False
) -> AnyViews:
    if "depth" in views:
        images, intrinsics, depths = rescale_and_crop(
            views["image"],
            views["intrinsics"],
            shape,
            intr_aug=intr_aug,
            depths=views["depth"],
        )
        return {
            **views,
            "image": images,
            "intrinsics": intrinsics,
            "depth": depths,
        }
    images, intrinsics, _ = rescale_and_crop(
        views["image"], views["intrinsics"], shape, intr_aug
    )
    return {
        **views,
        "image": images,
        "intrinsics": intrinsics,
    }

def apply_crop_shim(
    example: AnyExample, shape: tuple[int, int], intr_aug: bool = False
) -> AnyExample:
    """Crop images in the example."""
    return {
        **example,
        "context": apply_crop_shim_to_views(example["context"], shape, intr_aug),
        "target": apply_crop_shim_to_views(example["target"], shape, intr_aug),
    }

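# Minimal smoke test (a sketch with arbitrary shapes; the relative import
# above means this file must be run as a module within its package, e.g.
# `python -m <package>.crop_shim` with the actual package path).
if __name__ == "__main__":
    views = {
        "image": torch.rand(4, 3, 360, 640),
        "intrinsics": torch.eye(3).unsqueeze(0).repeat(4, 1, 1),
        "depth": torch.rand(4, 1, 360, 640),
    }
    example = {"context": views, "target": views}
    out = apply_crop_shim(example, (256, 256))
    assert out["context"]["image"].shape == (4, 3, 256, 256)
    assert out["context"]["depth"].shape == (4, 1, 256, 256)
    print("apply_crop_shim smoke test passed")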