Spaces:
Runtime error
Runtime error
| import io | |
| import os | |
| from pathlib import Path | |
| from typing import Union | |
| import cv2 | |
| import numpy as np | |
| import skvideo | |
| import torch | |
| import torchvision.transforms as tf | |
| from einops import rearrange, repeat | |
| from jaxtyping import Float, UInt8 | |
| from matplotlib import pyplot as plt | |
| from matplotlib.figure import Figure | |
| from PIL import Image | |
| from torch import Tensor | |
# Float image tensors accepted by the image helpers in this module, by shape:
#   "height width"                 -- single-channel image without a channel axis
#   "channel height width"         -- one image
#   "batch channel height width"   -- a batch of images
FloatImage = Union[
    Float[Tensor, "height width"],
    Float[Tensor, "channel height width"],
    Float[Tensor, "batch channel height width"],
]
def fig_to_image(
    fig: Figure,
    dpi: int = 100,
    device: torch.device = torch.device("cpu"),
) -> Float[Tensor, "3 height width"]:
    """Rasterize a matplotlib figure into an RGB float tensor.

    Args:
        fig: Figure to render.
        dpi: Resolution passed to ``fig.savefig`` for the raw RGBA dump.
        device: Device on which the returned tensor is created.

    Returns:
        Tensor of shape (3, height, width) obtained by dividing the 8-bit
        raw dump by 255; the fourth (alpha) channel is dropped.
    """
    with io.BytesIO() as buffer:
        fig.savefig(buffer, format="raw", dpi=dpi)
        # getvalue() returns an independent bytes object, so the array stays
        # valid after the buffer is closed.
        data = np.frombuffer(buffer.getvalue(), dtype=np.uint8)

    # The dump is rendered at `dpi`, whereas `fig.bbox` reflects the figure's
    # own dpi once savefig has returned. Derive the pixel size from the size
    # in inches so the reshape also works when `dpi` differs from `fig.dpi`.
    width_in, height_in = fig.get_size_inches()
    h = int(height_in * dpi)
    w = int(width_in * dpi)

    data = rearrange(data, "(h w c) -> c h w", h=h, w=w, c=4)
    return (torch.tensor(data, device=device, dtype=torch.float32) / 255)[:3]
def prep_image(image: FloatImage) -> UInt8[np.ndarray, "height width channel"]:
    """Convert a float image tensor into an 8-bit HWC NumPy array.

    A 4-D input is treated as a batch and its items are laid out side by side
    along the width. A 2-D input gets a leading channel axis. Single-channel
    images are replicated to three channels. Values are clipped to [0, 1]
    before being scaled by 255 and cast to uint8.
    """
    if image.ndim == 4:
        # Tile the batch horizontally into one wide image.
        image = rearrange(image, "b c h w -> c h (b w)")
    elif image.ndim == 2:
        # Give channel-less images an explicit channel axis.
        image = rearrange(image, "h w -> () h w")

    # Expand grayscale to RGB; afterwards only 3 or 4 channels are accepted.
    num_channels, _height, _width = image.shape
    if num_channels == 1:
        image = repeat(image, "() h w -> c h w", c=3)
    assert image.shape[0] in (3, 4)

    clipped = image.detach().clip(min=0, max=1)
    quantized = (clipped * 255).type(torch.uint8)
    channels_last = rearrange(quantized, "c h w -> h w c")
    return channels_last.cpu().numpy()
def save_image(
    image: FloatImage,
    path: Union[Path, str],
) -> None:
    """Write an image tensor (assumed to be in range 0-1) to ``path``.

    Missing parent directories of ``path`` are created first.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    # Convert to an 8-bit HWC array and let PIL pick the format from the path.
    pixels = prep_image(image)
    Image.fromarray(pixels).save(target)
def load_image(
    path: Union[Path, str],
) -> Float[Tensor, "3 height width"]:
    """Load an image file as a tensor via torchvision's ``ToTensor``.

    Args:
        path: Location of the image file.

    Returns:
        The converted image with only its first three channels kept, so a
        fourth (e.g. alpha) channel is dropped.
    """
    # Convert inside the context manager so the underlying file handle is
    # released deterministically instead of lingering until garbage collection.
    with Image.open(path) as pil_image:
        return tf.ToTensor()(pil_image)[:3]
def save_video(tensor, save_path, fps=10):
    """
    Save a tensor of shape (N, C, H, W) as a video file using imageio.

    Args:
        tensor: Tensor of shape (N, C, H, W) in range [0, 1]; values outside
            that range are clipped before quantization
        save_path: Path to save the video file
        fps: Frames per second for the video
    """
    import imageio

    # Convert to a channels-last NumPy array: (N, C, H, W) -> (N, H, W, C).
    frames = tensor.detach().cpu().numpy()
    frames = np.transpose(frames, (0, 2, 3, 1))

    # Clip before scaling so out-of-range values saturate instead of wrapping
    # around in the uint8 cast.
    frames = (np.clip(frames, 0.0, 1.0) * 255).astype(np.uint8)

    # A bare file name has an empty dirname, which os.makedirs rejects.
    parent = os.path.dirname(save_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # The context manager closes the writer even if a frame fails to encode.
    with imageio.get_writer(save_path, fps=fps) as writer:
        for frame in frames:
            writer.append_data(frame)
def save_interpolated_video(
    pred_extrinsics, pred_intrinsics, b, h, w, gaussians, save_path, decoder_func, t=10
):
    """Render a camera path interpolated between the input views and save it.

    Between every pair of neighboring views, ``t`` extra cameras are inserted.
    The resulting path (all input views plus the inserted ones) is rendered
    with ``decoder_func.forward`` and written as ``rgb.mp4`` and ``depth.mp4``
    inside ``save_path``.

    Args:
        pred_extrinsics: Camera extrinsics indexed as (batch, view, 4, 4).
        pred_intrinsics: Camera intrinsics indexed as (batch, view, ...).
        b: Batch size of the camera tensors.
        h: Image height, passed to the decoder as part of ``(h, w)``.
        w: Image width, passed to the decoder as part of ``(h, w)``.
        gaussians: Scene representation forwarded to the decoder unchanged.
        save_path: Directory that receives the two video files.
        decoder_func: Object whose ``forward`` returns a result exposing
            ``color`` and ``depth``.
        t: Number of extra views to interpolate between each pair.

    Returns:
        Tuple ``(rgb_video_path, depth_video_path)``.
    """
    num_views = pred_extrinsics.shape[1]
    extrinsic_frames = []
    intrinsic_frames = []

    # For each pair of neighboring views
    for i in range(num_views - 1):
        # Add the current view itself
        extrinsic_frames.append(pred_extrinsics[:, i : i + 1])
        intrinsic_frames.append(pred_intrinsics[:, i : i + 1])

        # The endpoints do not depend on the interpolation step, so slice
        # them once per pair.
        start_extrinsic = pred_extrinsics[:, i]
        end_extrinsic = pred_extrinsics[:, i + 1]
        start_rot = start_extrinsic[:, :3, :3]
        end_rot = end_extrinsic[:, :3, :3]
        start_trans = start_extrinsic[:, :3, 3]
        end_trans = end_extrinsic[:, :3, 3]
        start_intrinsic = pred_intrinsics[:, i]
        end_intrinsic = pred_intrinsics[:, i + 1]

        for j in range(1, t + 1):
            alpha = j / (t + 1)

            # Translation: plain linear interpolation
            interp_trans = (1 - alpha) * start_trans + alpha * end_trans

            # Rotation: blend the matrices linearly, then use the SVD to
            # project the blend back onto an orthogonal matrix (this is not
            # a true spherical interpolation).
            blended_rot = (1 - alpha) * start_rot + alpha * end_rot
            u, _, v = torch.svd(blended_rot)
            interp_rot = torch.bmm(u, v.transpose(1, 2))

            # Assemble the 4x4 extrinsic; match the input dtype so the final
            # concatenation does not mix dtypes.
            interp_extrinsic = (
                torch.eye(
                    4, device=pred_extrinsics.device, dtype=pred_extrinsics.dtype
                )
                .unsqueeze(0)
                .repeat(b, 1, 1)
            )
            interp_extrinsic[:, :3, :3] = interp_rot
            interp_extrinsic[:, :3, 3] = interp_trans

            # Intrinsics: plain linear interpolation
            interp_intrinsic = (1 - alpha) * start_intrinsic + alpha * end_intrinsic

            extrinsic_frames.append(interp_extrinsic.unsqueeze(1))
            intrinsic_frames.append(interp_intrinsic.unsqueeze(1))

    # Close the path with the final input view. This has to happen before the
    # concatenation, otherwise the last view is never rendered (and a
    # single-view input would leave nothing to concatenate).
    extrinsic_frames.append(pred_extrinsics[:, -1:])
    intrinsic_frames.append(pred_intrinsics[:, -1:])

    # Concatenate all frames along the view axis
    pred_all_extrinsic = torch.cat(extrinsic_frames, dim=1)
    pred_all_intrinsic = torch.cat(intrinsic_frames, dim=1)
    num_frames = pred_all_extrinsic.shape[1]

    # Render the interpolated views. The two constant tensors (0.1 and 100)
    # are presumably the near/far planes -- confirm against the decoder.
    interpolated_output = decoder_func.forward(
        gaussians,
        pred_all_extrinsic,
        pred_all_intrinsic.float(),
        torch.ones(1, num_frames, device=pred_all_extrinsic.device) * 0.1,
        torch.ones(1, num_frames, device=pred_all_extrinsic.device) * 100,
        (h, w),
    )

    # Only the first batch element is turned into a video
    video = interpolated_output.color[0].clip(min=0, max=1)
    depth = interpolated_output.depth[0]

    # Normalize depth for visualization. The quantiles are taken over every
    # `num_views`-th frame only, to avoid
    # `quantile() input tensor is too large`.
    depth_sample = depth[::num_views]
    depth_lo = depth_sample.quantile(0.01)
    depth_hi = depth_sample.quantile(0.99)
    # Guard against a zero range (constant depth) producing NaN/inf.
    depth_norm = (depth - depth_lo) / (depth_hi - depth_lo).clamp(min=1e-8)

    # Colorize with the turbo colormap and drop its alpha channel
    depth_rgba = plt.cm.turbo(depth_norm.detach().cpu().numpy())
    depth_colored = (
        torch.from_numpy(depth_rgba[..., :3]).permute(0, 3, 1, 2).to(depth.device)
    )
    depth_colored = depth_colored.clip(min=0, max=1)

    rgb_path = os.path.join(save_path, "rgb.mp4")
    depth_path = os.path.join(save_path, "depth.mp4")

    # Save depth video
    save_video(depth_colored, depth_path)
    # Save video
    save_video(video, rgb_path)

    return rgb_path, depth_path