import json
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from jaxtyping import Float
from torch import Tensor

from optgs.geometry.projection import get_fov, get_projection_matrix
from optgs.visualization.camera_trajectory.wobble import generate_wobble_transformation
from optgs.visualization.camera_trajectory.interpolation import (
    interpolate_extrinsics,
    interpolate_intrinsics,
)


def get_scene_scale(camtoworlds: Float[np.ndarray, "N 4 4"]) -> float:
    """Return the size of the scene as measured by the camera positions (as in gsplat).

    Args:
        camtoworlds: Camera-to-world matrices, shape ``[N, 4, 4]``.

    Returns:
        ``1.1 * max_i ||c_i - mean(c)||``, where ``c_i`` is the translation column
        (camera origin) of each matrix.
    """
    camera_locations = camtoworlds[:, :3, 3]
    scene_center = np.mean(camera_locations, axis=0)
    dists = np.linalg.norm(camera_locations - scene_center, axis=1)
    scene_scale = np.max(dists)
    return float(scene_scale) * 1.1


class Camera(nn.Module):
    """A camera: its parameters plus the image for one view (Re10k dataset).

    Attributes:
        idx: Initialised to -1 (presumably assigned later by the caller — TODO confirm).
        uid, colmap_id, image_name: Identifiers of the view.
        data_device: ``torch.device`` the camera tensors are moved to.
        extrinsics: C2W (camera-to-world) matrix, 4x4 tensor — not a W2C "extrinsics".
        intrinsics: K matrix, 3x3 tensor.
        extrinsics_render_view, intrinsics_render_view: Separate C2W / K pair stored
            for rendering (4x4 / 3x3 tensors).
        scale_matrix, trans_matrix: 4x4 tensors, stored as given.
        raw_image_shape: Shape tuple supplied by the caller.
        original_image: RGB image (3xHxW tensor) clamped to [0, 1]; it is NOT moved to
            ``data_device``.
        image_height, image_width: H and W of ``original_image``.
        gt_alpha_mask: Optional mask (1xHxW tensor) or None.
        znear, zfar: Near / far clipping plane distances (1-element tensors).
        trans, scale: Extra translation / scale values, stored as given.
        FoVx, FoVy: Fields of view in x / y (Python floats) computed from ``intrinsics``.
        projection_matrix: Projection matrix (4x4 tensor).
        world_view_transform: Transposed inverse of ``extrinsics`` (i.e. W2C transposed).
        full_proj_transform: ``world_view_transform @ projection_matrix.T``.
        camera_center: Camera origin in world coordinates (translation of the C2W).
    """

    def __init__(
        self,
        colmap_id: str,
        extrinsics: Float[Tensor, "4 4"],
        intrinsics: Float[Tensor, "3 3"],
        extrinsics_render_view: Float[Tensor, "4 4"],
        intrinsics_render_view: Float[Tensor, "3 3"],
        scale_matrix: Float[Tensor, "4 4"],
        trans_matrix: Float[Tensor, "4 4"],
        image: Float[Tensor, "3 h w"],
        raw_image_shape: tuple[int, int],
        image_name: str,
        uid: int,
        near: Float[Tensor, "1"],
        far: Float[Tensor, "1"],
        data_device: torch.device,
        gt_alpha_mask: Float[Tensor, "1 h w"] | None = None,
        trans: np.ndarray | None = None,
        scale: float = 1.0,
    ):
        super().__init__()

        self.idx = -1
        self.uid = uid
        self.colmap_id = colmap_id
        self.image_name = image_name

        # Normalise the device spec; torch.device() raises on an invalid spec, in which
        # case we fall back to the default CUDA device. (A bare assignment could never
        # fail, which made this fallback dead code.)
        try:
            self.data_device = torch.device(data_device)
        except (RuntimeError, TypeError) as e:
            print(e)
            print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device")
            self.data_device = torch.device("cuda")

        self.extrinsics = extrinsics.to(self.data_device)  # C2W matrix! (not really extrinsics)
        self.intrinsics = intrinsics.to(self.data_device)
        self.extrinsics_render_view = extrinsics_render_view.to(self.data_device)
        self.intrinsics_render_view = intrinsics_render_view.to(self.data_device)
        self.scale_matrix = scale_matrix.to(self.data_device)
        self.trans_matrix = trans_matrix.to(self.data_device)
        self.raw_image_shape = raw_image_shape

        self.original_image = image.clamp(0.0, 1.0)
        self.image_width = self.original_image.shape[2]
        self.image_height = self.original_image.shape[1]

        self.gt_alpha_mask = (
            gt_alpha_mask.to(self.data_device) if gt_alpha_mask is not None else None
        )

        self.zfar = far.to(self.data_device)
        self.znear = near.to(self.data_device)

        # A fresh array per instance (the previous default ndarray was shared by all cameras).
        self.trans = np.array([0.0, 0.0, 0.0]) if trans is None else trans
        self.scale = scale

        # get_fov works on a batch, hence unsqueeze(0); each fov is a 1-element tensor.
        fov_x, fov_y = get_fov(self.intrinsics.unsqueeze(0)).unbind(dim=-1)
        self.FoVx = fov_x.item()
        self.FoVy = fov_y.item()

        # Transposed matrices so that points are multiplied as row vectors: p @ V @ P.
        projection_matrix = get_projection_matrix(self.znear, self.zfar, fov_x, fov_y)
        projection_matrix = rearrange(projection_matrix, "b i j -> b j i")
        view_matrix = rearrange(self.extrinsics.inverse(), "i j -> j i")
        full_projection = (view_matrix.unsqueeze(0) @ projection_matrix)[0]

        self.camera_center = self.extrinsics[:3, 3]
        # Undo the transpose so this attribute holds the matrix as returned by get_projection_matrix.
        self.projection_matrix = projection_matrix[0].transpose(0, 1)
        self.world_view_transform = view_matrix
        self.full_proj_transform = full_projection

    def save(self, save_dir: Path):
        """Write this camera to ``save_dir / image_name`` (one .pt file per tensor + JSON).

        ``trans`` and ``scale`` are not persisted; ``load_camera`` restores their defaults.
        """
        cam_dir = Path(save_dir) / self.image_name
        os.makedirs(cam_dir, exist_ok=True)
        torch.save(self.extrinsics, cam_dir / "extrinsics.pt")
        torch.save(self.intrinsics, cam_dir / "intrinsics.pt")
        torch.save(self.extrinsics_render_view, cam_dir / "extrinsics_render_view.pt")
        torch.save(self.intrinsics_render_view, cam_dir / "intrinsics_render_view.pt")
        torch.save(self.scale_matrix, cam_dir / "scale_matrix.pt")
        torch.save(self.trans_matrix, cam_dir / "trans_matrix.pt")
        torch.save(self.original_image, cam_dir / "image.pt")
        if self.gt_alpha_mask is not None:
            torch.save(self.gt_alpha_mask, cam_dir / "gt_alpha_mask.pt")
        with open(cam_dir / "cam_info.json", "w") as f:
            json.dump(
                {
                    "colmap_id": self.colmap_id,
                    "image_name": self.image_name,
                    "uid": self.uid,
                    "raw_image_shape": self.raw_image_shape,
                    "near": self.znear.item(),
                    "far": self.zfar.item(),
                },
                f,
                indent=4,
            )

    @classmethod
    def load_camera(cls, cam_dir: Path, data_device: torch.device):
        """Rebuild a camera written by :meth:`save`.

        Raises:
            FileNotFoundError: if a required tensor file is missing (e.g. a directory
                written before the render-view / scale / trans matrices were saved).
        """
        cam_dir = Path(cam_dir)

        def _load(filename: str) -> Tensor:
            # map_location lets tensors saved on another device (e.g. CUDA) load here.
            return torch.load(cam_dir / filename, map_location=data_device)

        mask_path = cam_dir / "gt_alpha_mask.pt"
        gt_alpha_mask = _load(mask_path.name) if mask_path.exists() else None

        with open(cam_dir / "cam_info.json", "r") as f:
            cam_info = json.load(f)

        return cls(
            colmap_id=cam_info["colmap_id"],
            extrinsics=_load("extrinsics.pt"),
            intrinsics=_load("intrinsics.pt"),
            extrinsics_render_view=_load("extrinsics_render_view.pt"),
            intrinsics_render_view=_load("intrinsics_render_view.pt"),
            scale_matrix=_load("scale_matrix.pt"),
            trans_matrix=_load("trans_matrix.pt"),
            image=_load("image.pt"),
            gt_alpha_mask=gt_alpha_mask,
            raw_image_shape=tuple(cam_info["raw_image_shape"]),
            image_name=cam_info["image_name"],
            uid=cam_info["uid"],
            near=torch.Tensor([cam_info["near"]]).to(data_device),
            far=torch.Tensor([cam_info["far"]]).to(data_device),
            data_device=data_device,
        ).to(data_device)


def generate_cam_params_for_wobble(t: Tensor, cam_a: Camera, cam_b: Camera):
    """Interpolate between two cameras and apply a wobble transform.

    The wobble radius is half the distance between the two camera origins and
    ``num_rotations=1``.

    Returns:
        ``(extrinsics @ tf, intrinsics)`` as produced by the interpolation helpers.
    """
    origin_a = cam_a.extrinsics[:3, 3]
    origin_b = cam_b.extrinsics[:3, 3]
    delta = (origin_a - origin_b).norm(dim=-1)

    tf = generate_wobble_transformation(
        radius=delta * 0.5,
        t=t,
        num_rotations=1,
        scale_radius_with_t=False,
    )
    # NOTE(review): t in [0, 1] is shifted to [-2, -1] before interpolating; if the helpers
    # map 0 -> initial and 1 -> final this extrapolates beyond cam_a rather than moving
    # towards cam_b. Possibly a leftover from a scaled "t * k - 2" trajectory — confirm.
    t_shifted = t - 2
    extrinsics = interpolate_extrinsics(
        initial=cam_a.extrinsics,
        final=cam_b.extrinsics,
        t=t_shifted,
    )
    intrinsics = interpolate_intrinsics(
        initial=cam_a.intrinsics,
        final=cam_b.intrinsics,
        t=t_shifted,
    )
    return extrinsics @ tf, intrinsics


def generate_cam_params_for_interpolation(t: Tensor, cam_a: Camera, cam_b: Camera):
    """Interpolate both the regular and the render-view camera parameters of two cameras.

    Returns:
        ``(extrinsics, intrinsics, extrinsics_render_view, intrinsics_render_view)``.
    """
    # NOTE(review): same "t - 2" shift as generate_cam_params_for_wobble; for t in [0, 1]
    # the interpolation parameter lies in [-2, -1]. Confirm this is intended.
    t_shifted = t - 2
    extrinsics = interpolate_extrinsics(
        initial=cam_a.extrinsics,
        final=cam_b.extrinsics,
        t=t_shifted,
    )
    intrinsics = interpolate_intrinsics(
        initial=cam_a.intrinsics,
        final=cam_b.intrinsics,
        t=t_shifted,
    )
    extrinsics_render_view = interpolate_extrinsics(
        initial=cam_a.extrinsics_render_view,
        final=cam_b.extrinsics_render_view,
        t=t_shifted,
    )
    intrinsics_render_view = interpolate_intrinsics(
        initial=cam_a.intrinsics_render_view,
        final=cam_b.intrinsics_render_view,
        t=t_shifted,
    )
    return extrinsics, intrinsics, extrinsics_render_view, intrinsics_render_view


def get_intermediate_cameras(
    cam_a: Camera, cam_b: Camera, num_frames: int = 150, smooth: bool = False
):
    """Create ``num_frames`` cameras along the trajectory between ``cam_a`` and ``cam_b``.

    The new cameras reuse ``cam_a``'s image, clipping planes, device and scale/trans
    matrices; they carry no ground-truth image of their own.

    Args:
        smooth: If True, reparameterise t with a cosine ease-in/ease-out.
    """
    t = torch.linspace(0, 1, num_frames, dtype=torch.float32, device=cam_a.data_device)
    if smooth:
        # Maps 0 -> 0 and 1 -> 1 with zero slope at both ends.
        t = (torch.cos(torch.pi * (t + 1)) + 1) / 2

    extrinsics, intrinsics, extrinsics_render_view, intrinsics_render_view = (
        generate_cam_params_for_interpolation(t, cam_a, cam_b)
    )
    # Collapse any leading singleton dims into a single frame axis. Unlike squeeze(0),
    # reshape keeps the frame axis when num_frames == 1.
    extrinsics = extrinsics.reshape(-1, 4, 4)
    intrinsics = intrinsics.reshape(-1, 3, 3)
    extrinsics_render_view = extrinsics_render_view.reshape(-1, 4, 4)
    intrinsics_render_view = intrinsics_render_view.reshape(-1, 3, 3)

    return [
        Camera(
            colmap_id=cam_a.colmap_id,
            image_name=f"{cam_a.image_name}_{index:04d}",
            uid=index,
            near=cam_a.znear,
            far=cam_a.zfar,
            data_device=cam_a.data_device,
            # These views have no ground truth image but we should never require images for mesh views
            image=cam_a.original_image,
            raw_image_shape=cam_a.raw_image_shape,
            extrinsics=extrinsics[index],
            intrinsics=intrinsics[index],
            extrinsics_render_view=extrinsics_render_view[index],
            intrinsics_render_view=intrinsics_render_view[index],
            scale_matrix=cam_a.scale_matrix,
            trans_matrix=cam_a.trans_matrix,
            gt_alpha_mask=None,
        )
        for index in range(num_frames)
    ]


def patch_shim(cams: list[Camera], patch_size: int) -> list[Camera]:
    """Center-crop each camera's image (and mask) to a multiple of ``patch_size``.

    The principal point of both K matrices is shifted by the crop offset so the
    projection stays aligned with the cropped image.
    """
    new_cams = []
    for cam in cams:
        _, h, w = cam.original_image.shape
        # Even sizes keep the naive center crop symmetric.
        assert h % 2 == 0 and w % 2 == 0

        h_new = (h // patch_size) * patch_size
        row = (h - h_new) // 2
        w_new = (w // patch_size) * patch_size
        col = (w - w_new) // 2

        def _crop(x: Tensor) -> Tensor:
            return x[:, row : row + h_new, col : col + w_new]

        new_original_image = _crop(cam.original_image)
        # The mask must be cropped identically, otherwise it no longer matches the image.
        new_gt_alpha_mask = None if cam.gt_alpha_mask is None else _crop(cam.gt_alpha_mask)

        # Shift only the principal point (cx, cy).
        # NOTE(review): this assumes K is in pixel units; if K were normalised the
        # offsets would need dividing by the image size — confirm.
        new_intrinsics = cam.intrinsics.clone()
        new_intrinsics[0, 2] -= col
        new_intrinsics[1, 2] -= row

        # Same adjustment for the render-view K (previously whole rows were shifted,
        # which also corrupted fx / skew / fy).
        new_render_view_intrinsics = cam.intrinsics_render_view.clone()
        new_render_view_intrinsics[0, 2] -= col
        new_render_view_intrinsics[1, 2] -= row

        new_cams.append(
            Camera(
                colmap_id=cam.colmap_id,
                image_name=cam.image_name,
                uid=cam.uid,
                near=cam.znear,
                far=cam.zfar,
                data_device=cam.data_device,
                raw_image_shape=cam.raw_image_shape,
                image=new_original_image,
                extrinsics=cam.extrinsics,
                intrinsics=new_intrinsics,
                extrinsics_render_view=cam.extrinsics_render_view,
                intrinsics_render_view=new_render_view_intrinsics,
                scale_matrix=cam.scale_matrix,
                trans_matrix=cam.trans_matrix,
                gt_alpha_mask=new_gt_alpha_mask,
                trans=cam.trans,
                scale=cam.scale,
            )
        )
    return new_cams


def calculate_cameras_extent(cam_centers: Tensor):
    """Return ``(translate, radius)`` describing the spread of the camera centers.

    Args:
        cam_centers: Camera centers, one per row.

    Returns:
        ``translate``: negated mean center (flattened tensor);
        ``radius``: ``1.1 * max distance to the mean center`` as a Python float.
    """
    avg_cam_center = cam_centers.mean(dim=0, keepdim=True)
    dist = torch.norm(cam_centers - avg_cam_center, dim=-1, keepdim=True)
    diagonal = dist.max()
    center = avg_cam_center.flatten()
    radius = diagonal * 1.1
    translate = -center
    return translate, radius.item()


def save_cameras(cameras: list[Camera], save_dir: Path):
    """Save cameras as stacked tensors plus a ``cam_info.json`` metadata file.

    All cameras must have images of the same shape (they are stacked). Alpha masks
    are saved only when every camera has one. ``trans`` / ``scale`` are not persisted.

    Raises:
        ValueError: if ``cameras`` is empty.
    """
    if not cameras:
        raise ValueError("save_cameras() requires at least one camera")

    save_dir = Path(save_dir)
    os.makedirs(save_dir, exist_ok=True)

    def _stack_and_save(tensors: list[Tensor], filename: str) -> None:
        torch.save(torch.stack(tensors), save_dir / filename)

    _stack_and_save([cam.extrinsics for cam in cameras], "extrinsics.pt")
    _stack_and_save([cam.intrinsics for cam in cameras], "intrinsics.pt")
    _stack_and_save([cam.original_image for cam in cameras], "images.pt")
    _stack_and_save([cam.extrinsics_render_view for cam in cameras], "extrinsics_render_view.pt")
    _stack_and_save([cam.intrinsics_render_view for cam in cameras], "intrinsics_render_view.pt")
    _stack_and_save([cam.scale_matrix for cam in cameras], "scale_matrices.pt")
    _stack_and_save([cam.trans_matrix for cam in cameras], "trans_matrices.pt")
    # Masks must be stacked, so they are only saved when every camera has one.
    if all(cam.gt_alpha_mask is not None for cam in cameras):
        _stack_and_save([cam.gt_alpha_mask for cam in cameras], "gt_alpha_masks.pt")

    with open(save_dir / "cam_info.json", "w") as f:
        json.dump(
            {
                "num_cameras": len(cameras),
                "image_shape": [(cam.image_height, cam.image_width) for cam in cameras],
                "znear": [cam.znear.item() for cam in cameras],
                "zfar": [cam.zfar.item() for cam in cameras],
                "uids": [cam.uid for cam in cameras],
                "colmap_ids": [cam.colmap_id for cam in cameras],
                "raw_image_shapes": [cam.raw_image_shape for cam in cameras],
                "image_names": [cam.image_name for cam in cameras],
            },
            f,
            indent=4,
        )


def load_cameras(cam_dir: Path, device: torch.device) -> list[Camera]:
    """Load cameras written by :func:`save_cameras`.

    Image names fall back to ``image_{idx:04d}`` when the metadata has none.

    Raises:
        FileNotFoundError: if a required tensor file is missing (e.g. a directory
            written before the render-view / scale / trans matrices were saved).
    """
    cam_dir = Path(cam_dir)

    def _load(filename: str) -> Tensor:
        # map_location lets tensors saved on another device (e.g. CUDA) load here.
        return torch.load(cam_dir / filename, map_location=device)

    with open(cam_dir / "cam_info.json", "r") as f:
        cam_info = json.load(f)
    num_cameras = cam_info["num_cameras"]

    extrinsics = _load("extrinsics.pt")
    intrinsics = _load("intrinsics.pt")
    images = _load("images.pt")
    extrinsics_render_view = _load("extrinsics_render_view.pt")
    intrinsics_render_view = _load("intrinsics_render_view.pt")
    scale_matrices = _load("scale_matrices.pt")
    trans_matrices = _load("trans_matrices.pt")
    if (cam_dir / "gt_alpha_masks.pt").exists():
        gt_alpha_masks = _load("gt_alpha_masks.pt")
    else:
        gt_alpha_masks = [None] * num_cameras
    image_names = cam_info.get(
        "image_names", [f"image_{idx:04d}" for idx in range(num_cameras)]
    )

    return [
        Camera(
            colmap_id=cam_info["colmap_ids"][idx],
            image_name=image_names[idx],
            uid=cam_info["uids"][idx],
            near=torch.Tensor([cam_info["znear"][idx]]).to(device),
            far=torch.Tensor([cam_info["zfar"][idx]]).to(device),
            data_device=device,
            image=images[idx],
            extrinsics=extrinsics[idx],
            intrinsics=intrinsics[idx],
            extrinsics_render_view=extrinsics_render_view[idx],
            intrinsics_render_view=intrinsics_render_view[idx],
            scale_matrix=scale_matrices[idx],
            trans_matrix=trans_matrices[idx],
            raw_image_shape=tuple(cam_info["raw_image_shapes"][idx]),
            gt_alpha_mask=gt_alpha_masks[idx],
        )
        for idx in range(num_cameras)
    ]