Spaces:
Runtime error
Runtime error
| from typing import Protocol, runtime_checkable | |
| import cv2 | |
| import torch | |
| from einops import repeat, pack | |
| from jaxtyping import Float | |
| from torch import Tensor | |
| from .camera_trajectory.interpolation import interpolate_extrinsics, interpolate_intrinsics | |
| from .camera_trajectory.wobble import generate_wobble, generate_wobble_transformation | |
| from .layout import vcat | |
| from ..dataset.types import BatchedExample | |
| from ..misc.image_io import save_video | |
| from ..misc.utils import vis_depth_map | |
| from ..model.decoder import Decoder | |
| from ..model.types import Gaussians | |
| class TrajectoryFn(Protocol): | |
| def __call__( | |
| self, | |
| t: Float[Tensor, " t"], | |
| ) -> tuple[ | |
| Float[Tensor, "batch view 4 4"], # extrinsics | |
| Float[Tensor, "batch view 3 3"], # intrinsics | |
| ]: | |
| pass | |
| def render_video_wobble( | |
| gaussians: Gaussians, | |
| decoder: Decoder, | |
| batch: BatchedExample, | |
| num_frames: int = 60, | |
| smooth: bool = True, | |
| loop_reverse: bool = True, | |
| add_depth: bool = False, | |
| ) -> Tensor: | |
| # Two views are needed to get the wobble radius,use the first and the last view | |
| _, v, _, _ = batch["context"]["extrinsics"].shape | |
| def trajectory_fn(t): | |
| origin_a = batch["context"]["extrinsics"][:, 0, :3, 3] | |
| origin_b = batch["context"]["extrinsics"][:, -1, :3, 3] | |
| delta = (origin_a - origin_b).norm(dim=-1) | |
| extrinsics = generate_wobble( | |
| batch["context"]["extrinsics"][:, 0], | |
| delta * 0.25, | |
| t, | |
| ) | |
| intrinsics = repeat( | |
| batch["context"]["intrinsics"][:, 0], | |
| "b i j -> b v i j", | |
| v=t.shape[0], | |
| ) | |
| return extrinsics, intrinsics | |
| return render_video_generic(gaussians, decoder, batch, trajectory_fn, num_frames, smooth, loop_reverse, add_depth) | |
| def render_video_interpolation( | |
| gaussians: Gaussians, | |
| decoder: Decoder, | |
| batch: BatchedExample, | |
| num_frames: int = 60, | |
| smooth: bool = True, | |
| loop_reverse: bool = True, | |
| add_depth: bool = False, | |
| ) -> Tensor: | |
| _, v, _, _ = batch["context"]["extrinsics"].shape | |
| def trajectory_fn(t): | |
| extrinsics = interpolate_extrinsics( | |
| batch["context"]["extrinsics"][0, 0], | |
| batch["context"]["extrinsics"][0, -1], | |
| t, | |
| ) | |
| intrinsics = interpolate_intrinsics( | |
| batch["context"]["intrinsics"][0, 0], | |
| batch["context"]["intrinsics"][0, -1], | |
| t, | |
| ) | |
| return extrinsics[None], intrinsics[None] | |
| return render_video_generic(gaussians, decoder, batch, trajectory_fn, num_frames, smooth, loop_reverse, add_depth) | |
| def render_video_interpolation_exaggerated( | |
| gaussians: Gaussians, | |
| decoder: Decoder, | |
| batch: BatchedExample, | |
| num_frames: int = 300, | |
| smooth: bool = False, | |
| loop_reverse: bool = False, | |
| add_depth: bool = False, | |
| ) -> Tensor: | |
| # Two views are needed to get the wobble radius. | |
| _, v, _, _ = batch["context"]["extrinsics"].shape | |
| def trajectory_fn(t): | |
| origin_a = batch["context"]["extrinsics"][:, 0, :3, 3] | |
| origin_b = batch["context"]["extrinsics"][:, -1, :3, 3] | |
| delta = (origin_a - origin_b).norm(dim=-1) | |
| tf = generate_wobble_transformation( | |
| delta * 0.5, | |
| t, | |
| 5, | |
| scale_radius_with_t=False, | |
| ) | |
| extrinsics = interpolate_extrinsics( | |
| batch["context"]["extrinsics"][0, 0], | |
| batch["context"]["extrinsics"][0, -1], | |
| t * 5 - 2, | |
| ) | |
| intrinsics = interpolate_intrinsics( | |
| batch["context"]["intrinsics"][0, 0], | |
| batch["context"]["extrinsics"][0, -1], | |
| t * 5 - 2, | |
| ) | |
| return extrinsics @ tf, intrinsics[None] | |
| return render_video_generic(gaussians, decoder, batch, trajectory_fn, num_frames, smooth, loop_reverse, add_depth) | |
| def render_video_generic( | |
| gaussians: Gaussians, | |
| decoder: Decoder, | |
| batch: BatchedExample, | |
| trajectory_fn: TrajectoryFn, | |
| num_frames: int = 30, | |
| smooth: bool = True, | |
| loop_reverse: bool = True, | |
| add_depth: bool = False, | |
| ) -> Tensor: | |
| device = gaussians.means.device | |
| t = torch.linspace(0, 1, num_frames, dtype=torch.float32, device=device) | |
| if smooth: | |
| t = (torch.cos(torch.pi * (t + 1)) + 1) / 2 | |
| extrinsics, intrinsics = trajectory_fn(t) | |
| _, _, _, h, w = batch["context"]["image"].shape | |
| near = repeat(batch["context"]["near"][:, 0], "b -> b v", v=num_frames) | |
| far = repeat(batch["context"]["far"][:, 0], "b -> b v", v=num_frames) | |
| output = decoder.forward( | |
| gaussians, extrinsics, intrinsics, near, far, (h, w), "depth" | |
| ) | |
| images = [ | |
| vcat(rgb, depth) if add_depth else rgb | |
| for rgb, depth in zip(output.color[0], vis_depth_map(output.depth[0])) | |
| ] | |
| video = torch.stack(images) | |
| # video = (video.clip(min=0, max=1) * 255).type(torch.uint8).cpu().numpy() | |
| if loop_reverse: | |
| # video = pack([video, video[::-1][1:-1]], "* c h w")[0] | |
| video = pack([video, video.flip(dims=(0,))[1:-1]], "* c h w")[0] | |
| return video | |