| import copy |
| import random |
| import numpy as np |
| import torch |
| from jaxtyping import Float |
| from torch import Tensor |
|
|
| from ..types import AnyExample, AnyViews |
|
|
|
|
def reflect_extrinsics(
    extrinsics: Float[Tensor, "*batch 4 4"],
) -> Float[Tensor, "*batch 4 4"]:
    """Mirror camera extrinsics across the x axis.

    The input is conjugated by ``diag(-1, 1, 1, 1)``, i.e. the result is
    ``reflect @ extrinsics @ reflect``. Because the same matrix is applied on
    both sides, the operation is its own inverse.

    Args:
        extrinsics: 4x4 matrices with any number of leading batch dimensions.

    Returns:
        The reflected matrices, same shape, dtype and device as the input.
    """
    # Build the reflection in the input's dtype: ``@`` does not type-promote,
    # so a hard-coded float32 matrix would raise for float64/float16 inputs.
    reflect = torch.eye(4, dtype=extrinsics.dtype, device=extrinsics.device)
    reflect[0, 0] = -1
    return reflect @ extrinsics @ reflect
|
|
|
|
def reflect_views(views: AnyViews) -> AnyViews:
    """Return a mirrored copy of ``views``.

    ``views["image"]`` (and ``views["depth"]`` when that key is present) is
    flipped along its last axis, and ``views["extrinsics"]`` is passed through
    ``reflect_extrinsics``. Every other entry is copied over untouched and the
    input mapping itself is not modified.

    NOTE(review): nothing else is adjusted here. If the views also carry
    camera intrinsics, a flip of the last image axis would normally mirror the
    horizontal principal point as well -- confirm callers rely on a centred
    principal point.
    """
    mirrored = {
        **views,
        "image": views["image"].flip(-1),
        "extrinsics": reflect_extrinsics(views["extrinsics"]),
    }
    # Depth is optional; only mirror it when the caller supplied one.
    if "depth" in views:
        mirrored["depth"] = views["depth"].flip(-1)
    return mirrored
|
|
|
|
| def apply_augmentation_shim( |
| example: AnyExample, |
| generator: torch.Generator | None = None, |
| ) -> AnyExample: |
| """Randomly augment the training images.""" |
| |
| if torch.rand(tuple(), generator=generator) < 0.5: |
| return example |
| |
| return { |
| **example, |
| "context": reflect_views(example["context"]), |
| "target": reflect_views(example["target"]), |
| } |
|
|
def rotate_90_degrees(
    image: torch.Tensor,
    depth_map: torch.Tensor | None,
    extri_opencv: torch.Tensor,
    intri_opencv: torch.Tensor,
    clockwise=True,
):
    """
    Rotate an image, its optional depth map and its camera by a quarter turn.

    This is a thin coordinator: the pixel data goes through
    ``rotate_image_and_depth_rot90``, the intrinsics through
    ``adjust_intrinsic_matrix_rot90`` and the extrinsics through
    ``adjust_extrinsic_matrix_rot90``, all with the same ``clockwise`` flag.

    Args:
        image (torch.Tensor):
            Image tensor whose last two dimensions are (H, W), e.g. (C, H, W).
        depth_map (torch.Tensor or None):
            Depth map whose last two dimensions are (H, W), or None.
        extri_opencv (torch.Tensor):
            Extrinsic matrix; only its top-left 3x3 block and the first three
            entries of its fourth column are read by the helper.
        intri_opencv (torch.Tensor):
            Intrinsic matrix (3x3). The helper in this file rescales row 0 by
            the image width and row 1 by the image height before rotating, so
            it expects intrinsics normalised by the image size.
        clockwise (bool):
            True for a 90-degree clockwise rotation, False for
            counterclockwise.

    Returns:
        tuple:
            (rotated_image, rotated_depth_map, new_extri_opencv,
            new_intri_opencv). ``rotated_depth_map`` is None when no depth map
            was given. The extrinsic helper in this file appends a
            ``[0, 0, 0, 1]`` row, so ``new_extri_opencv`` is 4x4.
    """
    # The intrinsics helper needs the size of the *unrotated* image.
    original_height, original_width = image.shape[-2:]

    rotated_pixels = rotate_image_and_depth_rot90(image, depth_map, clockwise)
    rotated_intrinsics = adjust_intrinsic_matrix_rot90(
        intri_opencv, original_width, original_height, clockwise
    )
    rotated_extrinsics = adjust_extrinsic_matrix_rot90(extri_opencv, clockwise)

    rotated_image, rotated_depth_map = rotated_pixels
    return rotated_image, rotated_depth_map, rotated_extrinsics, rotated_intrinsics
|
|
|
|
| def rotate_image_and_depth_rot90(image: torch.Tensor, depth_map: torch.Tensor | None, clockwise: bool): |
| """ |
| Rotates the given image and depth map by 90 degrees (clockwise or counterclockwise). |
| |
| Args: |
| image (torch.Tensor): |
| Input image tensor of shape (C, H, W). |
| depth_map (torch.Tensor or None): |
| Depth map tensor of shape (H, W), or None if not available. |
| clockwise (bool): |
| If True, rotate 90 degrees clockwise; else 90 degrees counterclockwise. |
| |
| Returns: |
| tuple: |
| (rotated_image, rotated_depth_map) |
| """ |
| rotated_depth_map = None |
| if clockwise: |
| rotated_image = torch.rot90(image, k=-1, dims=[-2, -1]) |
| if depth_map is not None: |
| rotated_depth_map = torch.rot90(depth_map, k=-1, dims=[-2, -1]) |
| else: |
| rotated_image = torch.rot90(image, k=1, dims=[-2, -1]) |
| if depth_map is not None: |
| rotated_depth_map = torch.rot90(depth_map, k=1, dims=[-2, -1]) |
| return rotated_image, rotated_depth_map |
|
|
|
|
def adjust_extrinsic_matrix_rot90(extri_opencv: torch.Tensor, clockwise: bool):
    """
    Adjusts an extrinsic matrix for a 90-degree rotation of the image.

    The rotation part ``R`` and translation ``t`` are both left-multiplied by
    a quarter-turn about the camera z axis, and the result is returned as a
    homogeneous 4x4 matrix ``[[R', t'], [0, 0, 0, 1]]``.

    NOTE(review): left-multiplying ``R`` and ``t`` is the update for
    world-to-camera extrinsics -- confirm callers do not pass camera-to-world
    poses.

    Args:
        extri_opencv (torch.Tensor):
            Extrinsic matrix of shape (..., 3, 4) or (..., 4, 4). Only the
            top-left 3x3 block and the first three entries of the fourth
            column are read; any number of leading batch dimensions is
            accepted.
        clockwise (bool):
            If True, rotate extrinsic for a 90-degree clockwise image rotation;
            otherwise, counterclockwise.

    Returns:
        torch.Tensor:
            A new (..., 4, 4) extrinsic matrix with the same dtype and device
            as the input.
    """
    dtype, device = extri_opencv.dtype, extri_opencv.device

    R = extri_opencv[..., :3, :3]
    # Slice with 3:4 to keep t as a column so one matmul handles any batch.
    t = extri_opencv[..., :3, 3:4]

    if clockwise:
        rotation_rows = [[0, -1, 0], [1, 0, 0], [0, 0, 1]]
    else:
        rotation_rows = [[0, 1, 0], [-1, 0, 0], [0, 0, 1]]
    R_rotation = torch.tensor(rotation_rows, dtype=dtype, device=device)

    new_R = torch.matmul(R_rotation, R)
    new_t = torch.matmul(R_rotation, t)
    top_rows = torch.cat((new_R, new_t), dim=-1)

    # Homogeneous bottom row, broadcast over whatever batch shape came in.
    bottom_row = torch.tensor([0, 0, 0, 1], dtype=dtype, device=device)
    bottom_row = bottom_row.expand(*top_rows.shape[:-2], 1, 4)
    return torch.cat((top_rows, bottom_row), dim=-2)
|
|
|
|
def adjust_intrinsic_matrix_rot90(intri_opencv: torch.Tensor, image_width: int, image_height: int, clockwise: bool):
    """
    Adjusts a normalised intrinsic matrix for a 90-degree rotation of the image.

    The input is treated as normalised by the image size: row 0 is multiplied
    by ``image_width`` and row 1 by ``image_height`` to obtain pixel units,
    the focal lengths are swapped and the principal point is moved to its
    rotated location, and the result is re-normalised by the rotated image's
    size (new width = ``image_height``, new height = ``image_width``).

    The output is built from an identity matrix, so only fx, fy, cx and cy are
    carried over; any other entry of the input (e.g. skew) is dropped. The
    input tensor is not modified.

    Args:
        intri_opencv (torch.Tensor):
            Normalised intrinsic matrix of shape (..., 3, 3); any number of
            leading batch dimensions is accepted.
        image_width (int):
            Original width of the image.
        image_height (int):
            Original height of the image.
        clockwise (bool):
            If True, rotate 90 degrees clockwise; else 90 degrees counterclockwise.

    Returns:
        torch.Tensor:
            A new (..., 3, 3) normalised intrinsic matrix after the rotation.
    """
    # clone() instead of copy.deepcopy(): it protects the caller's tensor the
    # same way but also works for non-leaf / autograd-tracked tensors, which
    # deepcopy rejects.
    pixel_intri = intri_opencv.clone()
    pixel_intri[..., 0, :] *= image_width
    pixel_intri[..., 1, :] *= image_height

    fx = pixel_intri[..., 0, 0]
    fy = pixel_intri[..., 1, 1]
    cx = pixel_intri[..., 0, 2]
    cy = pixel_intri[..., 1, 2]

    # One identity matrix per batch element (a single 3x3 when unbatched).
    batch_shape = pixel_intri.shape[:-2]
    new_intri_opencv = torch.eye(
        3, dtype=pixel_intri.dtype, device=pixel_intri.device
    ).repeat(*batch_shape, 1, 1)

    # The x and y axes trade places under a quarter turn, in both directions.
    new_intri_opencv[..., 0, 0] = fy
    new_intri_opencv[..., 1, 1] = fx
    if clockwise:
        new_intri_opencv[..., 0, 2] = image_height - cy
        new_intri_opencv[..., 1, 2] = cx
    else:
        new_intri_opencv[..., 0, 2] = cy
        new_intri_opencv[..., 1, 2] = image_width - cx

    # Re-normalise by the rotated image's size: width and height are swapped.
    new_intri_opencv[..., 0, :] /= image_height
    new_intri_opencv[..., 1, :] /= image_width

    return new_intri_opencv
|
|