Spaces:
Runtime error
Runtime error
| import copy | |
| import random | |
| import numpy as np | |
| import torch | |
| from jaxtyping import Float | |
| from torch import Tensor | |
| from ..types import AnyExample, AnyViews | |
def reflect_extrinsics(
    extrinsics: Float[Tensor, "*batch 4 4"],
) -> Float[Tensor, "*batch 4 4"]:
    """Conjugate 4x4 extrinsics by a reflection that negates the first axis.

    Computes ``S @ extrinsics @ S`` with ``S = diag(-1, 1, 1, 1)``. The net
    effect is that row 0 and column 0 of every matrix change sign, while the
    ``[0, 0]`` entry is negated twice and therefore stays as it was.

    Args:
        extrinsics: Matrices of shape ``(*batch, 4, 4)``.

    Returns:
        A tensor with the same shape, dtype and device as ``extrinsics``.
    """
    # Build S in the input's own dtype. A hard-coded float32 matrix makes the
    # matmul raise a dtype-mismatch error as soon as the extrinsics are
    # stored in another precision (e.g. float64).
    reflect = torch.eye(4, dtype=extrinsics.dtype, device=extrinsics.device)
    reflect[0, 0] = -1
    return reflect @ extrinsics @ reflect
def reflect_views(views: AnyViews) -> AnyViews:
    """Return a mirrored copy of a view dictionary.

    The result is a new dict holding every entry of ``views``, with these
    entries replaced:

    * ``"image"``: flipped along its last dimension.
    * ``"extrinsics"``: passed through ``reflect_extrinsics``.
    * ``"depth"``: flipped along its last dimension, only when the key is
      present in ``views``.

    The input mapping itself is not modified.
    """
    mirrored = dict(views)
    mirrored["image"] = views["image"].flip(-1)
    mirrored["extrinsics"] = reflect_extrinsics(views["extrinsics"])
    # Depth is optional; mirror it alongside the image when it is provided.
    if "depth" in views:
        mirrored["depth"] = views["depth"].flip(-1)
    return mirrored
def apply_augmentation_shim(
    example: AnyExample,
    generator: torch.Generator | None = None,
) -> AnyExample:
    """Randomly mirror the context and target views of an example.

    One uniform sample is drawn from ``torch.rand`` (using ``generator`` when
    given). If the sample is below 0.5 the example is returned unchanged, as
    the very same object. Otherwise a new dict is returned in which
    ``"context"`` and ``"target"`` have been passed through ``reflect_views``;
    all other entries are carried over as they are.
    """
    draw = torch.rand((), generator=generator)
    # Leave the example untouched with 50% probability.
    if draw < 0.5:
        return example

    augmented = dict(example)
    augmented["context"] = reflect_views(example["context"])
    augmented["target"] = reflect_views(example["target"])
    return augmented
def rotate_90_degrees(
    image: torch.Tensor, depth_map: torch.Tensor | None, extri_opencv: torch.Tensor, intri_opencv: torch.Tensor, clockwise=True
):
    """
    Rotate an image, its optional depth map and the camera matrices by 90 degrees.

    This is a thin coordinator: the pixel data goes through
    ``rotate_image_and_depth_rot90``, the intrinsics through
    ``adjust_intrinsic_matrix_rot90`` (which receives the width and height of
    the *unrotated* image) and the extrinsics through
    ``adjust_extrinsic_matrix_rot90``.

    Args:
        image (torch.Tensor):
            Image tensor; its last two dimensions are read as (height, width).
        depth_map (torch.Tensor or None):
            Depth map matching the image, or None when there is none.
        extri_opencv (torch.Tensor):
            Camera extrinsic matrix, forwarded to the extrinsic helper.
        intri_opencv (torch.Tensor):
            Camera intrinsic matrix (3x3), forwarded to the intrinsic helper.
        clockwise (bool):
            True for a clockwise quarter turn, False for counterclockwise.

    Returns:
        tuple:
            ``(rotated_image, rotated_depth_map, new_extri_opencv, new_intri_opencv)``.
            ``rotated_depth_map`` is None when no depth map was supplied.
    """
    # Dimensions of the original (pre-rotation) image.
    src_height, src_width = image.shape[-2:]

    turned_image, turned_depth = rotate_image_and_depth_rot90(image, depth_map, clockwise)
    turned_intrinsics = adjust_intrinsic_matrix_rot90(intri_opencv, src_width, src_height, clockwise)
    turned_extrinsics = adjust_extrinsic_matrix_rot90(extri_opencv, clockwise)

    return turned_image, turned_depth, turned_extrinsics, turned_intrinsics
| def rotate_image_and_depth_rot90(image: torch.Tensor, depth_map: torch.Tensor | None, clockwise: bool): | |
| """ | |
| Rotates the given image and depth map by 90 degrees (clockwise or counterclockwise). | |
| Args: | |
| image (torch.Tensor): | |
| Input image tensor of shape (C, H, W). | |
| depth_map (torch.Tensor or None): | |
| Depth map tensor of shape (H, W), or None if not available. | |
| clockwise (bool): | |
| If True, rotate 90 degrees clockwise; else 90 degrees counterclockwise. | |
| Returns: | |
| tuple: | |
| (rotated_image, rotated_depth_map) | |
| """ | |
| rotated_depth_map = None | |
| if clockwise: | |
| rotated_image = torch.rot90(image, k=-1, dims=[-2, -1]) | |
| if depth_map is not None: | |
| rotated_depth_map = torch.rot90(depth_map, k=-1, dims=[-2, -1]) | |
| else: | |
| rotated_image = torch.rot90(image, k=1, dims=[-2, -1]) | |
| if depth_map is not None: | |
| rotated_depth_map = torch.rot90(depth_map, k=1, dims=[-2, -1]) | |
| return rotated_image, rotated_depth_map | |
def adjust_extrinsic_matrix_rot90(extri_opencv: torch.Tensor, clockwise: bool):
    """
    Update an extrinsic matrix for a 90-degree in-plane image rotation.

    The rotation block and the translation column are both left-multiplied by
    a fixed quarter-turn rotation about the third camera axis, then
    reassembled into a homogeneous matrix.

    Args:
        extri_opencv (torch.Tensor):
            Extrinsic matrix. Only its first three rows and first four
            columns are read, so both 3x4 and 4x4 inputs are accepted.
        clockwise (bool):
            True to match a clockwise image rotation, False for
            counterclockwise.

    Returns:
        torch.Tensor:
            A new 4x4 matrix ``[[R', t'], [0, 0, 0, 1]]`` with the dtype and
            device of the input.
    """
    like_input = {"dtype": extri_opencv.dtype, "device": extri_opencv.device}

    if clockwise:
        quarter_turn_rows = [[0, -1, 0], [1, 0, 0], [0, 0, 1]]
    else:
        quarter_turn_rows = [[0, 1, 0], [-1, 0, 0], [0, 0, 1]]
    quarter_turn = torch.tensor(quarter_turn_rows, **like_input)

    rotation = extri_opencv[:3, :3]
    translation = extri_opencv[:3, 3]

    turned_rotation = quarter_turn @ rotation
    turned_translation = quarter_turn @ translation

    # Stitch [R' | t'] back together and append the homogeneous bottom row.
    top_rows = torch.cat((turned_rotation, turned_translation.reshape(-1, 1)), dim=1)
    bottom_row = torch.tensor([[0, 0, 0, 1]], **like_input)
    return torch.cat((top_rows, bottom_row), dim=0)
def adjust_intrinsic_matrix_rot90(intri_opencv: torch.Tensor, image_width: int, image_height: int, clockwise: bool):
    """
    Adjusts the intrinsic matrix (3x3) for a 90-degree rotation of the image in the image plane.

    The input is treated as normalized by the image size: row 0 is scaled by
    ``image_width`` and row 1 by ``image_height`` to obtain pixel units, the
    focal lengths are swapped and the principal point is remapped, and the
    result is normalized again by the size of the *rotated* image (whose
    width is the original height and vice versa).

    Only ``fx``, ``fy``, ``cx`` and ``cy`` are read; every other entry of the
    result comes from the identity matrix, so any skew term is dropped.

    Args:
        intri_opencv (torch.Tensor):
            Intrinsic matrix (3x3). It is not modified.
        image_width (int):
            Original width of the image.
        image_height (int):
            Original height of the image.
        clockwise (bool):
            If True, rotate 90 degrees clockwise; else 90 degrees counterclockwise.

    Returns:
        torch.Tensor:
            A new 3x3 intrinsic matrix after the rotation, with the dtype and
            device of the input.
    """
    # Read the four used entries with out-of-place arithmetic. The previous
    # deepcopy + in-place scaling failed for tensors that require grad and
    # copied the whole matrix just to read four values.
    fx = intri_opencv[0, 0] * image_width
    fy = intri_opencv[1, 1] * image_height
    cx = intri_opencv[0, 2] * image_width
    cy = intri_opencv[1, 2] * image_height

    # Principal point in the rotated image, still in pixel units.
    if clockwise:
        new_cx, new_cy = image_height - cy, cx
    else:
        new_cx, new_cy = cy, image_width - cx

    new_intri_opencv = torch.eye(3, dtype=intri_opencv.dtype, device=intri_opencv.device)
    # After the turn the image is image_height wide and image_width tall, so
    # the x-row is normalized by the old height and the y-row by the old width.
    new_intri_opencv[0, 0] = fy / image_height
    new_intri_opencv[0, 2] = new_cx / image_height
    new_intri_opencv[1, 1] = fx / image_width
    new_intri_opencv[1, 2] = new_cy / image_width
    return new_intri_opencv