SteEsp's picture
Add Docker-based Learn2Splat demo (viser GUI)
78d2329 verified
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
from jaxtyping import Float
from torch import Tensor
from pathlib import Path
import os
import json
from optgs.geometry.projection import get_fov, get_projection_matrix
from optgs.visualization.camera_trajectory.wobble import generate_wobble_transformation
from optgs.visualization.camera_trajectory.interpolation import interpolate_extrinsics, interpolate_intrinsics
def get_scene_scale(camtoworlds: Float[np.ndarray, "N 4 4"]) -> float:
    """Estimate the scene size from camera positions, following gsplat.

    The size is the largest distance between any camera position and the
    mean camera position, inflated by 10%.

    Args:
        camtoworlds: Stack of camera-to-world matrices, shape ``[N, 4, 4]``.

    Returns:
        The inflated maximum camera-to-centroid distance as a Python float.
    """
    # The translation column of a camera-to-world matrix is the camera position.
    positions = camtoworlds[:, :3, 3]
    offsets = positions - positions.mean(axis=0)
    farthest = np.linalg.norm(offsets, axis=1).max()
    return float(farthest) * 1.1
class Camera(nn.Module):
    """
    A camera class that stores the camera parameters and the image for Re10k dataset.

    Attributes:
        idx: Index slot, initialised to -1.
        uid: Integer identifier of the view.
        colmap_id: Source-dataset identifier of the view.
        image_name: Name of the view; ``save`` uses it as the sub-directory name.
        data_device: Device the camera tensors are moved to.
        extrinsics: C2W matrix (4x4 torch.Tensor) -- camera-to-world, despite the name.
        intrinsics: K matrix (3x3 torch.Tensor).
        extrinsics_render_view: Extrinsics (4x4 torch.Tensor) used for the render view.
        intrinsics_render_view: Intrinsics (3x3 torch.Tensor) used for the render view.
        scale_matrix: 4x4 scale matrix stored with the camera.
        trans_matrix: 4x4 translation matrix stored with the camera.
        raw_image_shape: Image shape supplied by the caller (presumably the
            pre-resize shape -- confirm with the dataset loader).
        original_image: RGB image clamped to [0, 1] (3xHxW torch.Tensor); it is
            NOT moved to ``data_device``.
        image_height: Height of the image.
        image_width: Width of the image.
        gt_alpha_mask: Optional mask (1xHxW torch.Tensor) or None.
        znear: Near clipping plane distance.
        zfar: Far clipping plane distance.
        trans: Translation vector stored with the camera (defaults to zeros).
        scale: Scale factor stored with the camera.
        FoVx: Field of view in x direction, as returned by ``get_fov``.
        FoVy: Field of view in y direction, as returned by ``get_fov``.
        projection_matrix: Projection matrix from ``get_projection_matrix`` (4x4).
        world_view_transform: Transposed W2C matrix (4x4 torch.Tensor).
        full_proj_transform: ``world_view_transform @ projection_matrix.T`` (4x4).
        camera_center: Camera center, i.e. the translation column of the C2W matrix.
    """

    def __init__(
        self,
        colmap_id: str,
        extrinsics: Float[Tensor, "4 4"],
        intrinsics: Float[Tensor, "3 3"],
        extrinsics_render_view: Float[Tensor, "4 4"],
        intrinsics_render_view: Float[Tensor, "3 3"],
        scale_matrix: Float[Tensor, "4 4"],
        trans_matrix: Float[Tensor, "4 4"],
        image: Float[Tensor, "3 h w"],
        raw_image_shape: tuple[int, int],
        image_name: str,
        uid: int,
        near: Float[Tensor, "1"],
        far: Float[Tensor, "1"],
        data_device: torch.device,
        gt_alpha_mask: Float[Tensor, "1 h w"] | None = None,
        trans: np.ndarray | None = None,
        scale: float = 1.0,
    ):
        super().__init__()
        self.idx = -1
        self.uid = uid
        self.colmap_id = colmap_id
        self.image_name = image_name
        self.data_device = data_device

        self.extrinsics = extrinsics.to(self.data_device)  # C2W matrix! (not really extrinsics)
        self.intrinsics = intrinsics.to(self.data_device)
        self.extrinsics_render_view = extrinsics_render_view.to(self.data_device)
        self.intrinsics_render_view = intrinsics_render_view.to(self.data_device)
        self.scale_matrix = scale_matrix.to(self.data_device)
        self.trans_matrix = trans_matrix.to(self.data_device)
        self.raw_image_shape = raw_image_shape

        self.original_image = image.clamp(0.0, 1.0)
        self.image_width = self.original_image.shape[2]
        self.image_height = self.original_image.shape[1]
        self.gt_alpha_mask = gt_alpha_mask.to(self.data_device) if gt_alpha_mask is not None else None

        self.zfar = far.to(self.data_device)
        self.znear = near.to(self.data_device)
        # Fresh array per instance: a shared mutable default would leak in-place edits between cameras.
        self.trans = np.array([0.0, 0.0, 0.0]) if trans is None else trans
        self.scale = scale

        fov_x, fov_y = get_fov(self.intrinsics.unsqueeze(0)).unbind(dim=-1)
        self.FoVx = fov_x.item()
        self.FoVy = fov_y.item()

        # Matrices are stored transposed (row-vector convention) so that the
        # full transform is view @ projection.
        projection_matrix = get_projection_matrix(self.znear, self.zfar, fov_x, fov_y)
        projection_matrix = rearrange(projection_matrix, "b i j -> b j i")
        view_matrix = rearrange(self.extrinsics.inverse(), "i j -> j i")
        full_projection = (view_matrix.unsqueeze(0) @ projection_matrix)[0]

        self.camera_center = self.extrinsics[:3, 3]
        # Transposing back yields the projection matrix as returned by get_projection_matrix.
        self.projection_matrix = projection_matrix[0].transpose(0, 1)
        self.world_view_transform = view_matrix
        self.full_proj_transform = full_projection

    def save(self, save_dir: Path):
        """Save this camera into ``save_dir / image_name``.

        Writes every tensor ``__init__`` needs (including the render-view
        parameters and scale/translation matrices) plus ``cam_info.json``.
        ``trans``/``scale`` are not persisted.

        Args:
            save_dir: Parent directory; the per-camera sub-directory is created if missing.
        """
        cam_dir = Path(save_dir) / self.image_name
        os.makedirs(cam_dir, exist_ok=True)
        torch.save(self.extrinsics, cam_dir / "extrinsics.pt")
        torch.save(self.intrinsics, cam_dir / "intrinsics.pt")
        torch.save(self.extrinsics_render_view, cam_dir / "extrinsics_render_view.pt")
        torch.save(self.intrinsics_render_view, cam_dir / "intrinsics_render_view.pt")
        torch.save(self.scale_matrix, cam_dir / "scale_matrix.pt")
        torch.save(self.trans_matrix, cam_dir / "trans_matrix.pt")
        torch.save(self.original_image, cam_dir / "image.pt")
        if self.gt_alpha_mask is not None:
            torch.save(self.gt_alpha_mask, cam_dir / "gt_alpha_mask.pt")
        with open(cam_dir / "cam_info.json", "w") as f:
            json.dump(
                {
                    "colmap_id": self.colmap_id,
                    "image_name": self.image_name,
                    "uid": self.uid,
                    "raw_image_shape": self.raw_image_shape,
                    "near": self.znear.item(),
                    "far": self.zfar.item()
                },
                f,
                indent=4,
            )

    @classmethod
    def load_camera(cls, cam_dir: Path, data_device: torch.device):
        """Rebuild a camera from a directory written by ``save``.

        The render-view parameters and scale/translation matrices are optional
        on disk. When a file is absent:
        - the render view falls back to a copy of the primary extrinsics/intrinsics;
        - each matrix falls back to the 4x4 identity.
        These fallbacks assume no extra scene normalisation -- TODO confirm for
        such files. ``trans``/``scale`` take their defaults.

        Args:
            cam_dir: Directory containing the saved tensors and ``cam_info.json``.
            data_device: Device the tensors are loaded onto.

        Returns:
            The reconstructed Camera.
        """
        cam_dir = Path(cam_dir)

        def load_required(name: str):
            # map_location lets tensors saved on a GPU load on any target device.
            return torch.load(cam_dir / name, map_location=data_device)

        def load_optional(name: str):
            path = cam_dir / name
            return torch.load(path, map_location=data_device) if path.exists() else None

        extrinsics = load_required("extrinsics.pt")
        intrinsics = load_required("intrinsics.pt")
        image = load_required("image.pt")
        gt_alpha_mask = load_optional("gt_alpha_mask.pt")

        extrinsics_render_view = load_optional("extrinsics_render_view.pt")
        if extrinsics_render_view is None:
            extrinsics_render_view = extrinsics.clone()
        intrinsics_render_view = load_optional("intrinsics_render_view.pt")
        if intrinsics_render_view is None:
            intrinsics_render_view = intrinsics.clone()
        scale_matrix = load_optional("scale_matrix.pt")
        if scale_matrix is None:
            scale_matrix = torch.eye(4, dtype=extrinsics.dtype, device=extrinsics.device)
        trans_matrix = load_optional("trans_matrix.pt")
        if trans_matrix is None:
            trans_matrix = torch.eye(4, dtype=extrinsics.dtype, device=extrinsics.device)

        with open(cam_dir / "cam_info.json", "r") as f:
            cam_info = json.load(f)

        return cls(
            colmap_id=cam_info["colmap_id"],
            extrinsics=extrinsics.to(data_device),
            intrinsics=intrinsics.to(data_device),
            extrinsics_render_view=extrinsics_render_view.to(data_device),
            intrinsics_render_view=intrinsics_render_view.to(data_device),
            scale_matrix=scale_matrix.to(data_device),
            trans_matrix=trans_matrix.to(data_device),
            image=image.to(data_device),
            gt_alpha_mask=gt_alpha_mask.to(data_device) if gt_alpha_mask is not None else None,
            raw_image_shape=tuple(cam_info["raw_image_shape"]),
            image_name=cam_info["image_name"],
            uid=cam_info["uid"],
            near=torch.Tensor([cam_info["near"]]).to(data_device),
            far=torch.Tensor([cam_info["far"]]).to(data_device),
            data_device=data_device,
        ).to(data_device)
def generate_cam_params_for_wobble(t: Tensor, cam_a: Camera, cam_b: Camera):
    """Compute camera parameters for a wobbling path between two cameras.

    The wobble radius is half the distance between the two camera positions
    (translation columns of their C2W ``extrinsics``). The wobble transform
    is right-multiplied onto the interpolated extrinsics.

    Args:
        t: Path parameter values.
        cam_a: Camera used as the interpolation ``initial``.
        cam_b: Camera used as the interpolation ``final``.

    Returns:
        ``(extrinsics @ wobble_transform, intrinsics)`` along the path.
    """
    baseline = (cam_a.extrinsics[:3, 3] - cam_b.extrinsics[:3, 3]).norm(dim=-1)
    wobble_tf = generate_wobble_transformation(
        radius=baseline * 0.5,
        t=t,
        num_rotations=1,
        scale_radius_with_t=False,
    )
    # NOTE(review): the interpolation parameter is ``t - 2``; for t in [0, 1]
    # that is [-2, -1]. Confirm this is the intended parameterisation of
    # interpolate_extrinsics / interpolate_intrinsics.
    path_extrinsics = interpolate_extrinsics(
        initial=cam_a.extrinsics,
        final=cam_b.extrinsics,
        t=t - 2,
    )
    path_intrinsics = interpolate_intrinsics(
        initial=cam_a.intrinsics,
        final=cam_b.intrinsics,
        t=t - 2,
    )
    return path_extrinsics @ wobble_tf, path_intrinsics
def generate_cam_params_for_interpolation(t: Tensor, cam_a: Camera, cam_b: Camera):
    """Interpolate primary and render-view camera parameters between two cameras.

    Args:
        t: Path parameter values.
        cam_a: Camera used as the interpolation ``initial``.
        cam_b: Camera used as the interpolation ``final``.

    Returns:
        ``(extrinsics, intrinsics, extrinsics_render_view, intrinsics_render_view)``
        as produced by interpolate_extrinsics / interpolate_intrinsics.
    """
    # NOTE(review): every interpolation uses ``t - 2``; for t in [0, 1] that is
    # [-2, -1]. Confirm this matches the convention of the interpolate_* helpers.
    def blend_extrinsics(attr: str):
        return interpolate_extrinsics(
            initial=getattr(cam_a, attr),
            final=getattr(cam_b, attr),
            t=t - 2,
        )

    def blend_intrinsics(attr: str):
        return interpolate_intrinsics(
            initial=getattr(cam_a, attr),
            final=getattr(cam_b, attr),
            t=t - 2,
        )

    primary_extrinsics = blend_extrinsics("extrinsics")
    primary_intrinsics = blend_intrinsics("intrinsics")
    render_extrinsics = blend_extrinsics("extrinsics_render_view")
    render_intrinsics = blend_intrinsics("intrinsics_render_view")
    return primary_extrinsics, primary_intrinsics, render_extrinsics, render_intrinsics
def get_intermediate_cameras(cam_a: Camera, cam_b: Camera, num_frames: int = 150, smooth: bool = False):
    """Build ``num_frames`` cameras along the interpolated path from ``cam_a`` to ``cam_b``.

    Args:
        cam_a: Camera supplying the path start and every non-pose attribute
            (near/far, device, image, scale/trans matrices, ...).
        cam_b: Camera supplying the path end.
        num_frames: Number of cameras to generate.
        smooth: If True, remap the path parameter with a cosine ease-in/ease-out.

    Returns:
        A list of ``num_frames`` Camera objects named ``<cam_a.image_name>_<index>``.
    """
    t = torch.linspace(0, 1, num_frames, dtype=torch.float32, device=cam_a.data_device)
    if smooth:
        # Cosine easing: 0 -> 0 and 1 -> 1, with zero slope at both ends.
        t = (torch.cos(torch.pi * (t + 1)) + 1) / 2

    params = generate_cam_params_for_interpolation(t, cam_a, cam_b)
    extrinsics, intrinsics, extrinsics_rv, intrinsics_rv = (p.squeeze(0) for p in params)

    cameras = []
    for frame in range(num_frames):
        cameras.append(
            Camera(
                colmap_id=cam_a.colmap_id,
                image_name=f"{cam_a.image_name}_{frame:04d}",
                uid=frame,
                near=cam_a.znear,
                far=cam_a.zfar,
                data_device=cam_a.data_device,
                # No ground-truth image exists for these views; cam_a's image is a placeholder.
                image=cam_a.original_image,
                raw_image_shape=cam_a.raw_image_shape,
                extrinsics=extrinsics[frame],
                intrinsics=intrinsics[frame],
                extrinsics_render_view=extrinsics_rv[frame],
                intrinsics_render_view=intrinsics_rv[frame],
                scale_matrix=cam_a.scale_matrix,
                trans_matrix=cam_a.trans_matrix,
                gt_alpha_mask=None,
            )
        )
    return cameras
def _shift_principal_point(intrinsics: Tensor, col: int, row: int) -> Tensor:
    """Return a copy of pixel-space ``intrinsics`` with cx/cy shifted by a crop offset."""
    shifted = intrinsics.clone()
    shifted[0, 2] -= col
    shifted[1, 2] -= row
    return shifted


def patch_shim(cams: list[Camera], patch_size: int) -> list[Camera]:
    """Center-crop each camera's image so both sides are multiples of ``patch_size``.

    For every camera, a new Camera is returned in which:
    - the image is cropped;
    - the principal point (``K[0, 2]``, ``K[1, 2]``) of both the primary and
      render-view intrinsics is shifted by the crop offset;
    - the ground-truth mask, if any, is cropped with the same window.
    All other attributes, including ``trans``/``scale``, are carried over.
    The input cameras are not modified.

    Args:
        cams: Cameras to crop. Images must have even height and width.
        patch_size: Patch size that the cropped height and width must divide.

    Returns:
        A list of cropped cameras, in input order.
    """
    new_cams = []
    for cam in cams:
        _, h, w = cam.original_image.shape
        # Image size must be even so that naive center-cropping does not cause misalignment.
        assert h % 2 == 0 and w % 2 == 0
        h_new = (h // patch_size) * patch_size
        row = (h - h_new) // 2
        w_new = (w // patch_size) * patch_size
        col = (w - w_new) // 2

        # Apply the same window to the image and the mask so they stay aligned.
        window = (slice(None), slice(row, row + h_new), slice(col, col + w_new))
        new_original_image = cam.original_image[window]
        new_gt_alpha_mask = cam.gt_alpha_mask[window] if cam.gt_alpha_mask is not None else None

        new_intrinsics = _shift_principal_point(cam.intrinsics, col, row)
        # Assumes the render view shares the image's pixel grid -- TODO confirm.
        new_render_view_intrinsics = _shift_principal_point(cam.intrinsics_render_view, col, row)

        new_cams.append(
            Camera(
                colmap_id=cam.colmap_id,
                image_name=cam.image_name,
                uid=cam.uid,
                near=cam.znear,
                far=cam.zfar,
                data_device=cam.data_device,
                raw_image_shape=cam.raw_image_shape,
                image=new_original_image,
                extrinsics=cam.extrinsics,
                intrinsics=new_intrinsics,
                extrinsics_render_view=cam.extrinsics_render_view,
                intrinsics_render_view=new_render_view_intrinsics,
                scale_matrix=cam.scale_matrix,
                trans_matrix=cam.trans_matrix,
                gt_alpha_mask=new_gt_alpha_mask,
                trans=cam.trans,
                scale=cam.scale,
            )
        )
    return new_cams
def calculate_cameras_extent(cam_centers: Tensor):
    """Compute a recentring translation and an enclosing radius for camera centres.

    Args:
        cam_centers: Camera centres, one per row (reduced over dim 0).

    Returns:
        ``(translate, radius)``. ``translate`` is the negated mean centre
        (flattened). ``radius`` is 1.1 times the largest distance from the mean
        centre, as a Python float.
    """
    mean_center = cam_centers.mean(dim=0, keepdim=True)
    farthest = (cam_centers - mean_center).norm(dim=-1, keepdim=True).max()
    translate = -mean_center.flatten()
    return translate, (farthest * 1.1).item()
def save_cameras(cameras: list[Camera], save_dir: Path):
    """Save a list of cameras to ``save_dir`` as stacked tensors plus metadata.

    Files written:
    - ``extrinsics.pt``, ``intrinsics.pt``, ``images.pt``;
    - the render-view parameters (``extrinsics_render_view.pt``,
      ``intrinsics_render_view.pt``);
    - the scale/translation matrices (``scale_matrices.pt``, ``trans_matrices.pt``);
    - ``gt_alpha_masks.pt``, only when every camera has a mask;
    - ``cam_info.json`` with per-camera metadata, including image names.

    Stacking requires all cameras to share tensor shapes (e.g. image size).

    Args:
        cameras: Cameras to save; must be non-empty (``torch.stack`` fails otherwise).
        save_dir: Output directory; created if missing.
    """
    save_dir = Path(save_dir)
    os.makedirs(save_dir, exist_ok=True)

    stacked = {
        "extrinsics.pt": [cam.extrinsics for cam in cameras],
        "intrinsics.pt": [cam.intrinsics for cam in cameras],
        "images.pt": [cam.original_image for cam in cameras],
        # Camera.__init__ requires these, so they must survive a save/load round trip.
        "extrinsics_render_view.pt": [cam.extrinsics_render_view for cam in cameras],
        "intrinsics_render_view.pt": [cam.intrinsics_render_view for cam in cameras],
        "scale_matrices.pt": [cam.scale_matrix for cam in cameras],
        "trans_matrices.pt": [cam.trans_matrix for cam in cameras],
    }
    for filename, tensors in stacked.items():
        torch.save(torch.stack(tensors), save_dir / filename)

    # Save masks only when every camera has one. Checking only cameras[0]
    # crashed torch.stack on a later None, or silently dropped the other masks.
    if all(cam.gt_alpha_mask is not None for cam in cameras):
        gt_alpha_masks = torch.stack([cam.gt_alpha_mask for cam in cameras])
        torch.save(gt_alpha_masks, save_dir / "gt_alpha_masks.pt")

    with open(save_dir / "cam_info.json", "w") as f:
        json.dump(
            {
                "num_cameras": len(cameras),
                "image_shape": [(cam.image_height, cam.image_width) for cam in cameras],
                "znear": [cam.znear.item() for cam in cameras],
                "zfar": [cam.zfar.item() for cam in cameras],
                "uids": [cam.uid for cam in cameras],
                "colmap_ids": [cam.colmap_id for cam in cameras],
                "raw_image_shapes": [cam.raw_image_shape for cam in cameras],
                "image_names": [cam.image_name for cam in cameras],
            },
            f,
            indent=4,
        )
def load_cameras(cam_dir: Path, device: torch.device) -> list[Camera]:
    """Load cameras written by ``save_cameras``.

    The render-view parameters, scale/translation matrices, masks and image
    names are optional on disk, so older directories still load. When missing:
    - the render view falls back to the primary extrinsics/intrinsics;
    - each matrix falls back to the 4x4 identity (assumes no extra scene
      normalisation -- TODO confirm for such files);
    - image names default to ``image_XXXX``.

    Args:
        cam_dir: Directory containing the stacked tensors and ``cam_info.json``.
        device: Device the tensors are loaded onto.

    Returns:
        One Camera per saved entry, in saved order.
    """
    cam_dir = Path(cam_dir)

    def load_optional(name: str):
        path = cam_dir / name
        return torch.load(path, map_location=device) if path.exists() else None

    # map_location lets tensors saved on a GPU load on any target device.
    extrinsics = torch.load(cam_dir / "extrinsics.pt", map_location=device)
    intrinsics = torch.load(cam_dir / "intrinsics.pt", map_location=device)
    images = torch.load(cam_dir / "images.pt", map_location=device)
    gt_alpha_masks = load_optional("gt_alpha_masks.pt")
    extrinsics_render_view = load_optional("extrinsics_render_view.pt")
    intrinsics_render_view = load_optional("intrinsics_render_view.pt")
    scale_matrices = load_optional("scale_matrices.pt")
    trans_matrices = load_optional("trans_matrices.pt")

    with open(cam_dir / "cam_info.json", "r") as f:
        cam_info = json.load(f)

    num_cameras = cam_info["num_cameras"]
    image_names = cam_info.get("image_names") or [f"image_{idx:04d}" for idx in range(num_cameras)]
    identity = torch.eye(4, dtype=extrinsics.dtype, device=device)

    cameras = []
    for idx in range(num_cameras):
        ext_rv = extrinsics_render_view[idx] if extrinsics_render_view is not None else extrinsics[idx].clone()
        int_rv = intrinsics_render_view[idx] if intrinsics_render_view is not None else intrinsics[idx].clone()
        scale_matrix = scale_matrices[idx] if scale_matrices is not None else identity.clone()
        trans_matrix = trans_matrices[idx] if trans_matrices is not None else identity.clone()
        cameras.append(
            Camera(
                colmap_id=cam_info["colmap_ids"][idx],
                image_name=image_names[idx],
                uid=cam_info["uids"][idx],
                near=torch.Tensor([cam_info["znear"][idx]]).to(device),
                far=torch.Tensor([cam_info["zfar"][idx]]).to(device),
                data_device=device,
                image=images[idx].to(device),
                extrinsics=extrinsics[idx].to(device),
                intrinsics=intrinsics[idx].to(device),
                extrinsics_render_view=ext_rv.to(device),
                intrinsics_render_view=int_rv.to(device),
                scale_matrix=scale_matrix.to(device),
                trans_matrix=trans_matrix.to(device),
                raw_image_shape=tuple(cam_info["raw_image_shapes"][idx]),
                gt_alpha_mask=gt_alpha_masks[idx].to(device) if gt_alpha_masks is not None else None,
            )
        )
    return cameras