| | import io |
| | import os |
| | from pathlib import Path |
| | from typing import Union |
| |
|
| | import cv2 |
| | import numpy as np |
| | import skvideo |
| | import torch |
| | import torchvision.transforms as tf |
| | from einops import rearrange, repeat |
| | from jaxtyping import Float, UInt8 |
| |
|
| | from matplotlib import pyplot as plt |
| | from matplotlib.figure import Figure |
| | from PIL import Image |
| | from torch import Tensor |
| |
|
# Float image tensors accepted by the helpers below, in one of three layouts.
_GrayFloatImage = Float[Tensor, "height width"]  # single-channel map
_ChannelFloatImage = Float[Tensor, "channel height width"]  # one image
_BatchFloatImage = Float[Tensor, "batch channel height width"]  # stack of images

FloatImage = Union[_GrayFloatImage, _ChannelFloatImage, _BatchFloatImage]
| |
|
| |
|
def fig_to_image(
    fig: Figure,
    dpi: int = 100,
    device: torch.device = torch.device("cpu"),
) -> Float[Tensor, "3 height width"]:
    """Rasterize a matplotlib figure into an RGB float tensor.

    The figure is rendered with ``savefig(format="raw")``, which emits raw
    RGBA bytes; the alpha channel is dropped from the result.

    Args:
        fig: Figure to render.
        dpi: Resolution used for rendering. The output size in pixels is the
            figure size in inches multiplied by this value (truncated).
        device: Device on which the returned tensor is allocated.

    Returns:
        Float32 tensor of shape (3, height, width) with values in [0, 1].
    """
    with io.BytesIO() as buffer:
        fig.savefig(buffer, format="raw", dpi=dpi)
        # getvalue() returns an independent bytes copy, so it outlives the buffer.
        data = np.frombuffer(buffer.getvalue(), dtype=np.uint8)

    # The raster is produced at the *save* dpi. ``fig.bbox`` is expressed in
    # pixels at ``fig.dpi``, so reading the size from it gives the wrong shape
    # whenever ``dpi != fig.dpi``. Agg truncates the float size to int.
    width_in, height_in = fig.get_size_inches()
    w = int(width_in * dpi)
    h = int(height_in * dpi)

    data = rearrange(data, "(h w c) -> c h w", h=h, w=w, c=4)
    return (torch.tensor(data, device=device, dtype=torch.float32) / 255)[:3]
| |
|
| |
|
def prep_image(image: FloatImage) -> UInt8[np.ndarray, "height width channel"]:
    """Convert a float image tensor into an 8-bit, channels-last numpy array.

    Accepted layouts:
      * (height, width): treated as a single channel;
      * (channel, height, width);
      * (batch, channel, height, width): batch items are tiled left to right
        into one wide image.

    Single-channel images are replicated to 3 channels. Values are clipped to
    [0, 1] and scaled to 0-255; the cast to uint8 truncates rather than rounds.

    Raises:
        ValueError: if the tensor is not 2-, 3- or 4-dimensional, or if the
            channel count is not 1, 3 or 4.
    """
    if image.ndim not in (2, 3, 4):
        raise ValueError(
            f"Expected a 2D, 3D or 4D image tensor, got shape {tuple(image.shape)}."
        )

    # Tile a batch horizontally: batch index varies slowest along the width.
    if image.ndim == 4:
        image = rearrange(image, "b c h w -> c h (b w)")

    # Give a bare (height, width) map an explicit channel axis.
    if image.ndim == 2:
        image = rearrange(image, "h w -> () h w")

    channel, _, _ = image.shape
    if channel == 1:
        image = repeat(image, "() h w -> c h w", c=3)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if image.shape[0] not in (3, 4):
        raise ValueError(
            f"Expected an image with 1, 3 or 4 channels, got {image.shape[0]}."
        )

    image = (image.detach().clip(min=0, max=1) * 255).type(torch.uint8)
    return rearrange(image, "c h w -> h w c").cpu().numpy()
| |
|
| |
|
def save_image(
    image: FloatImage,
    path: Union[Path, str],
) -> None:
    """Write a float image tensor, expected to be in the 0-1 range, to ``path``.

    Any missing parent directories of ``path`` are created first. The tensor
    is converted to 8-bit pixels by ``prep_image`` and written with PIL.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    pixels = prep_image(image)
    Image.fromarray(pixels).save(target)
| |
|
| |
|
def load_image(
    path: Union[Path, str],
) -> Float[Tensor, "3 height width"]:
    """Load an image file as an RGB float tensor with values in [0, 1].

    The image is always converted to RGB, so the result has exactly three
    channels:
      * grayscale input is replicated across the channels;
      * an alpha channel is dropped;
      * palette and CMYK images yield real colors instead of raw
        indices or ink values.

    Args:
        path: Path of the image file.

    Returns:
        Float tensor of shape (3, height, width).
    """
    # The context manager guarantees the file handle is released.
    with Image.open(path) as image:
        return tf.ToTensor()(image.convert("RGB"))
| |
|
| |
|
def save_video(tensor, save_path, fps=10):
    """
    Save a tensor of shape (N, C, H, W) as a video file using imageio.

    Args:
        tensor: Tensor of shape (N, C, H, W) with values in [0, 1]. Values
            outside that range are clipped before conversion to 8-bit.
        save_path: Path to save the video file. Missing parent directories
            are created; a bare filename writes into the current directory.
        fps: Frames per second for the video.
    """
    import imageio

    frames = tensor.detach().cpu().numpy()
    # (N, C, H, W) -> (N, H, W, C): one H x W x C array per frame.
    frames = np.transpose(frames, (0, 2, 3, 1))
    # Clip first: casting out-of-range floats to uint8 is undefined in numpy
    # and typically wraps around, producing garbage pixels.
    frames = (np.clip(frames, 0, 1) * 255).astype(np.uint8)

    directory = os.path.dirname(save_path)
    if directory:  # os.makedirs("") would raise FileNotFoundError
        os.makedirs(directory, exist_ok=True)

    # The context manager closes (and finalizes) the file even if a write fails.
    with imageio.get_writer(save_path, fps=fps) as writer:
        for frame in frames:
            writer.append_data(frame)
| |
|
| |
|
def _project_to_rotation(matrices):
    """Return the proper rotation (det +1) nearest to each 3x3 matrix.

    ``matrices`` has shape (..., 3, 3). Uses the SVD projection ``U @ Vh``.
    Where that product would be a reflection (det -1), the last column of
    ``U`` is negated so the result stays a rotation.
    """
    u, _, vh = torch.linalg.svd(matrices)
    sign = 1.0 - 2.0 * (torch.linalg.det(u @ vh) < 0).to(u.dtype)
    u = torch.cat([u[..., :-1], u[..., -1:] * sign[..., None, None]], dim=-1)
    return u @ vh


def _interpolate_camera_path(pred_extrinsics, pred_intrinsics, t):
    """Insert ``t`` interpolated cameras between every pair of consecutive views.

    Translations and intrinsics are blended linearly. Rotations are blended
    element-wise and then projected back onto a proper rotation matrix.

    Args:
        pred_extrinsics: (batch, views, 4, 4) camera extrinsics.
        pred_intrinsics: (batch, views, ...) camera intrinsics (presumably
            3x3 matrices; only linear blending is applied to them).
        t: Number of interpolated cameras per segment.

    Returns:
        ``(extrinsics, intrinsics)`` with ``(views - 1) * (t + 1) + 1``
        cameras along dim 1. Every original view appears in order, with ``t``
        interpolated cameras between consecutive ones.
    """
    batch, num_views = pred_extrinsics.shape[:2]
    extrinsics = []
    intrinsics = []

    for i in range(num_views - 1):
        # Keep the original keyframe, then move toward the next one.
        extrinsics.append(pred_extrinsics[:, i : i + 1])
        intrinsics.append(pred_intrinsics[:, i : i + 1])

        # Segment endpoints do not depend on alpha; compute them once.
        start_rot = pred_extrinsics[:, i, :3, :3].reshape(batch, 9)
        end_rot = pred_extrinsics[:, i + 1, :3, :3].reshape(batch, 9)
        start_trans = pred_extrinsics[:, i, :3, 3]
        end_trans = pred_extrinsics[:, i + 1, :3, 3]
        start_intrinsic = pred_intrinsics[:, i]
        end_intrinsic = pred_intrinsics[:, i + 1]

        for j in range(1, t + 1):
            alpha = j / (t + 1)

            interp_rot = _project_to_rotation(
                ((1 - alpha) * start_rot + alpha * end_rot).reshape(batch, 3, 3)
            )
            interp_extrinsic = (
                torch.eye(4, dtype=pred_extrinsics.dtype, device=pred_extrinsics.device)
                .unsqueeze(0)
                .repeat(batch, 1, 1)
            )
            interp_extrinsic[:, :3, :3] = interp_rot
            interp_extrinsic[:, :3, 3] = (1 - alpha) * start_trans + alpha * end_trans

            extrinsics.append(interp_extrinsic.unsqueeze(1))
            intrinsics.append(
                ((1 - alpha) * start_intrinsic + alpha * end_intrinsic).unsqueeze(1)
            )

    # The loop only emits views 0..V-2, so close the path with the last view
    # (this also makes a single-view input yield that one camera).
    extrinsics.append(pred_extrinsics[:, -1:])
    intrinsics.append(pred_intrinsics[:, -1:])
    return torch.cat(extrinsics, dim=1), torch.cat(intrinsics, dim=1)


def _colorize_depth(depth, stride):
    """Map a (frames, H, W) depth stack to turbo-colored (frames, 3, H, W) in [0, 1].

    The normalization range is the 1st-99th percentile of every ``stride``-th
    frame. The range is clamped away from zero so constant depth cannot
    produce NaN/inf.
    """
    sample = depth[::stride]
    low = sample.quantile(0.01)
    high = sample.quantile(0.99)
    depth_norm = (depth - low) / (high - low).clamp_min(torch.finfo(depth.dtype).eps)

    # The colormap returns RGBA in the last axis; keep RGB, move it to dim 1.
    rgba = plt.cm.turbo(depth_norm.cpu().numpy())
    colored = torch.from_numpy(rgba[..., :3]).permute(0, 3, 1, 2).to(depth.device)
    return colored.clip(min=0, max=1)


def save_interpolated_video(
    pred_extrinsics,
    pred_intrinsics,
    b,
    h,
    w,
    gaussians,
    save_path,
    decoder_func,
    t=10,
    near=0.1,
    far=100.0,
):
    """Render a camera path interpolated between predicted views and save it.

    Args:
        pred_extrinsics: (b, views, 4, 4) camera extrinsics of the keyframes.
        pred_intrinsics: (b, views, ...) camera intrinsics of the keyframes.
        b: Batch size; must equal ``pred_extrinsics.shape[0]``.
        h, w: Render resolution passed to the decoder.
        gaussians: Scene representation forwarded to ``decoder_func``.
        save_path: Directory that receives ``rgb.mp4`` and ``depth.mp4``.
        decoder_func: Object whose ``forward`` renders the cameras. It is
            assumed to return an object with batched ``color`` and ``depth``.
        t: Number of interpolated frames between consecutive views.
        near, far: Near/far clipping planes used for every rendered frame.

    Returns:
        ``(rgb_path, depth_path)`` of the written videos. Only batch element 0
        is rendered to video.
    """
    pred_all_extrinsic, pred_all_intrinsic = _interpolate_camera_path(
        pred_extrinsics, pred_intrinsics, t
    )
    num_frames = pred_all_extrinsic.shape[1]
    device = pred_all_extrinsic.device

    # Near/far planes are per (batch, frame); float() keeps them floating-point
    # even if an int is passed.
    interpolated_output = decoder_func.forward(
        gaussians,
        pred_all_extrinsic,
        pred_all_intrinsic.float(),
        torch.full((b, num_frames), float(near), device=device),
        torch.full((b, num_frames), float(far), device=device),
        (h, w),
    )

    video = interpolated_output.color[0].clip(min=0, max=1)
    # NOTE(review): the depth range is estimated on every `views`-th rendered
    # frame, which does not align with keyframe positions (stride t + 1);
    # kept as originally written.
    depth_colored = _colorize_depth(
        interpolated_output.depth[0], stride=pred_extrinsics.shape[1]
    )

    rgb_path = os.path.join(save_path, "rgb.mp4")
    depth_path = os.path.join(save_path, "depth.mp4")
    save_video(depth_colored, depth_path)
    save_video(video, rgb_path)
    return rgb_path, depth_path
| |
|