| from typing import Protocol, runtime_checkable |
|
|
| import cv2 |
| import torch |
| from einops import repeat, pack |
| from jaxtyping import Float |
| from torch import Tensor |
|
|
| from .camera_trajectory.interpolation import interpolate_extrinsics, interpolate_intrinsics |
| from .camera_trajectory.wobble import generate_wobble, generate_wobble_transformation |
| from .layout import vcat |
| from ..dataset.types import BatchedExample |
| from ..misc.image_io import save_video |
| from ..misc.utils import vis_depth_map |
| from ..model.decoder import Decoder |
| from ..model.types import Gaussians |
|
|
|
|
@runtime_checkable
class TrajectoryFn(Protocol):
    """Structural type for the camera paths consumed by ``render_video_generic``.

    An implementation receives a 1-D tensor of frame times and returns a pair
    ``(extrinsics, intrinsics)`` of per-view 4x4 and 3x3 camera matrices.
    Because the protocol is ``runtime_checkable``, ``isinstance`` only checks
    that a ``__call__`` attribute exists; it does not validate the signature.
    """

    def __call__(
        self,
        t: Float[Tensor, " t"],
    ) -> tuple[
        Float[Tensor, "batch view 4 4"],
        Float[Tensor, "batch view 3 3"],
    ]:
        """Return (extrinsics, intrinsics) for each time value in ``t``."""
        ...
|
|
|
|
def render_video_wobble(
    gaussians: Gaussians,
    decoder: Decoder,
    batch: BatchedExample,
    num_frames: int = 60,
    smooth: bool = True,
    loop_reverse: bool = True,
    add_depth: bool = False,
) -> Tensor:
    """Render a video that wobbles the camera around the first context view.

    The wobble radius is a quarter of the distance between the camera origins
    (translation columns of the extrinsics) of the first and last context
    views, computed per batch element.

    Args:
        gaussians: Scene to render.
        decoder: Renderer passed through to ``render_video_generic``.
        batch: Example whose context extrinsics/intrinsics define the path.
        num_frames: Number of frames along the trajectory.
        smooth: Forwarded to ``render_video_generic``.
        loop_reverse: Forwarded to ``render_video_generic``.
        add_depth: Forwarded to ``render_video_generic``.

    Returns:
        The video tensor produced by ``render_video_generic``.
    """
    context_extrinsics = batch["context"]["extrinsics"]
    context_intrinsics = batch["context"]["intrinsics"]

    # The radius depends only on the batch, not on t, so compute it up front.
    origin_a = context_extrinsics[:, 0, :3, 3]
    origin_b = context_extrinsics[:, -1, :3, 3]
    radius = (origin_a - origin_b).norm(dim=-1) * 0.25

    def trajectory_fn(t):
        extrinsics = generate_wobble(context_extrinsics[:, 0], radius, t)
        # Every frame reuses the first context view's intrinsics.
        intrinsics = repeat(
            context_intrinsics[:, 0],
            "b i j -> b v i j",
            v=t.shape[0],
        )
        return extrinsics, intrinsics

    return render_video_generic(
        gaussians,
        decoder,
        batch,
        trajectory_fn,
        num_frames,
        smooth,
        loop_reverse,
        add_depth,
    )
|
|
|
|
def render_video_interpolation(
    gaussians: Gaussians,
    decoder: Decoder,
    batch: BatchedExample,
    num_frames: int = 60,
    smooth: bool = True,
    loop_reverse: bool = True,
    add_depth: bool = False,
) -> Tensor:
    """Render a video interpolating between the first and last context views.

    Only batch element 0 is used to build the path; the returned camera
    tensors get a leading batch axis of size 1.

    Args:
        gaussians: Scene to render.
        decoder: Renderer passed through to ``render_video_generic``.
        batch: Example whose context cameras are the interpolation endpoints.
        num_frames: Number of frames along the trajectory.
        smooth: Forwarded to ``render_video_generic``.
        loop_reverse: Forwarded to ``render_video_generic``.
        add_depth: Forwarded to ``render_video_generic``.

    Returns:
        The video tensor produced by ``render_video_generic``.
    """
    context = batch["context"]

    def trajectory_fn(t):
        extrinsics = interpolate_extrinsics(
            context["extrinsics"][0, 0],
            context["extrinsics"][0, -1],
            t,
        )
        intrinsics = interpolate_intrinsics(
            context["intrinsics"][0, 0],
            context["intrinsics"][0, -1],
            t,
        )
        # [None] restores the batch axis dropped by indexing element 0.
        return extrinsics[None], intrinsics[None]

    return render_video_generic(
        gaussians,
        decoder,
        batch,
        trajectory_fn,
        num_frames,
        smooth,
        loop_reverse,
        add_depth,
    )
|
|
|
|
def render_video_interpolation_exaggerated(
    gaussians: Gaussians,
    decoder: Decoder,
    batch: BatchedExample,
    num_frames: int = 300,
    smooth: bool = False,
    loop_reverse: bool = False,
    add_depth: bool = False,
) -> Tensor:
    """Render an interpolation pushed past both context views, with a wobble.

    The path interpolates between the first and last context cameras of batch
    element 0. It uses ``t * 5 - 2``, which maps ``t`` in [0, 1] to [-2, 3],
    so the path extends beyond both endpoints. That relies on the
    interpolation helpers extrapolating outside [0, 1] (presumably they do;
    confirm). A wobble transform with radius half the distance between the
    two camera origins is applied on top.

    Args:
        gaussians: Scene to render.
        decoder: Renderer passed through to ``render_video_generic``.
        batch: Example whose context cameras define the path.
        num_frames: Number of frames along the trajectory.
        smooth: Forwarded to ``render_video_generic``.
        loop_reverse: Forwarded to ``render_video_generic``.
        add_depth: Forwarded to ``render_video_generic``.

    Returns:
        The video tensor produced by ``render_video_generic``.
    """
    context = batch["context"]

    def trajectory_fn(t):
        origin_a = context["extrinsics"][:, 0, :3, 3]
        origin_b = context["extrinsics"][:, -1, :3, 3]
        delta = (origin_a - origin_b).norm(dim=-1)
        # NOTE(review): the positional 5 is generate_wobble_transformation's
        # third argument (likely the number of wobble cycles) -- confirm.
        tf = generate_wobble_transformation(
            delta * 0.5,
            t,
            5,
            scale_radius_with_t=False,
        )
        t_exaggerated = t * 5 - 2
        extrinsics = interpolate_extrinsics(
            context["extrinsics"][0, 0],
            context["extrinsics"][0, -1],
            t_exaggerated,
        )
        # Both endpoints must be intrinsics (3x3); previously the second
        # endpoint was mistakenly taken from the 4x4 extrinsics.
        intrinsics = interpolate_intrinsics(
            context["intrinsics"][0, 0],
            context["intrinsics"][0, -1],
            t_exaggerated,
        )
        return extrinsics @ tf, intrinsics[None]

    return render_video_generic(
        gaussians,
        decoder,
        batch,
        trajectory_fn,
        num_frames,
        smooth,
        loop_reverse,
        add_depth,
    )
|
|
|
|
def render_video_generic(
    gaussians: Gaussians,
    decoder: Decoder,
    batch: BatchedExample,
    trajectory_fn: TrajectoryFn,
    num_frames: int = 30,
    smooth: bool = True,
    loop_reverse: bool = True,
    add_depth: bool = False,
) -> Tensor:
    """Render ``gaussians`` along the camera path produced by ``trajectory_fn``.

    Args:
        gaussians: Scene to render; ``gaussians.means.device`` sets the device
            used for the frame times.
        decoder: Renderer called once for all frames.
        batch: Example whose context views supply the image size and the
            near/far planes (taken from the first context view).
        trajectory_fn: Maps frame times ``t`` to ``(extrinsics, intrinsics)``.
        num_frames: Number of frame times sampled in [0, 1].
        smooth: If True, remap ``t`` with a cosine ease-in/ease-out.
        loop_reverse: If True, append the frames in reverse order without the
            two endpoints, so the video plays forward then back.
        add_depth: If True, stack a depth visualization under each RGB frame.

    Returns:
        The frames of batch element 0, stacked along a new leading frame axis.
    """
    device = gaussians.means.device

    t = torch.linspace(0, 1, num_frames, dtype=torch.float32, device=device)
    if smooth:
        # Cosine easing: still maps 0 -> 0 and 1 -> 1, but slows the camera
        # near both ends of the path.
        t = (torch.cos(torch.pi * (t + 1)) + 1) / 2

    extrinsics, intrinsics = trajectory_fn(t)

    _, _, _, h, w = batch["context"]["image"].shape

    near = repeat(batch["context"]["near"][:, 0], "b -> b v", v=num_frames)
    far = repeat(batch["context"]["far"][:, 0], "b -> b v", v=num_frames)
    # NOTE(review): the "depth" argument presumably asks the decoder to also
    # return depth -- confirm against Decoder.forward.
    output = decoder.forward(
        gaussians, extrinsics, intrinsics, near, far, (h, w), "depth"
    )

    color = output.color[0]
    if add_depth:
        # Colourize depth only when it is actually displayed.
        images = [
            vcat(rgb, depth)
            for rgb, depth in zip(color, vis_depth_map(output.depth[0]))
        ]
    else:
        images = list(color)
    video = torch.stack(images)

    if loop_reverse:
        # Drop both endpoints of the reversed copy so the ping-pong loop never
        # shows the first or last frame twice in a row.
        video = pack([video, video.flip(dims=(0,))[1:-1]], "* c h w")[0]

    return video
|
|