Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| from typing import Literal, Optional | |
| import moviepy.editor as mpy | |
| import torch | |
| from depth_anything_3.model.utils.gs_renderer import run_renderer_in_chunk_w_trj_mode | |
| from depth_anything_3.specs import Prediction | |
| from depth_anything_3.utils.gsply_helpers import save_gaussian_ply | |
| from depth_anything_3.utils.layout_helpers import hcat, vcat | |
| from depth_anything_3.utils.visualize import vis_depth_map_tensor | |
| VIDEO_QUALITY_MAP = { | |
| "low": {"crf": "28", "preset": "veryfast"}, | |
| "medium": {"crf": "23", "preset": "medium"}, | |
| "high": {"crf": "18", "preset": "slow"}, | |
| } | |
| def export_to_gs_ply( | |
| prediction: Prediction, | |
| export_dir: str, | |
| gs_views_interval: Optional[ | |
| int | |
| ] = 1, # export GS every N views, useful for extremely dense inputs | |
| ): | |
| gs_world = prediction.gaussians | |
| pred_depth = torch.from_numpy(prediction.depth).unsqueeze(-1).to(gs_world.means) # v h w 1 | |
| idx = 0 | |
| os.makedirs(os.path.join(export_dir, "gs_ply"), exist_ok=True) | |
| save_path = os.path.join(export_dir, f"gs_ply/{idx:04d}.ply") | |
| if gs_views_interval is None: # select around 12 views in total | |
| gs_views_interval = max(pred_depth.shape[0] // 12, 1) | |
| save_gaussian_ply( | |
| gaussians=gs_world, | |
| save_path=save_path, | |
| ctx_depth=pred_depth, | |
| shift_and_scale=False, | |
| save_sh_dc_only=True, | |
| gs_views_interval=gs_views_interval, | |
| inv_opacity=True, | |
| prune_by_depth_percent=0.9, | |
| prune_border_gs=True, | |
| match_3dgs_mcmc_dev=False, | |
| ) | |
| def export_to_gs_video( | |
| prediction: Prediction, | |
| export_dir: str, | |
| extrinsics: Optional[torch.Tensor] = None, # render views' world2cam, "b v 4 4" | |
| intrinsics: Optional[torch.Tensor] = None, # render views' unnormed intrinsics, "b v 3 3" | |
| out_image_hw: Optional[tuple[int, int]] = None, # render views' resolution, (h, w) | |
| chunk_size: Optional[int] = 4, | |
| trj_mode: Literal[ | |
| "original", | |
| "smooth", | |
| "interpolate", | |
| "interpolate_smooth", | |
| "wander", | |
| "dolly_zoom", | |
| "extend", | |
| "wobble_inter", | |
| ] = "extend", | |
| color_mode: Literal["RGB+D", "RGB+ED"] = "RGB+ED", | |
| vis_depth: Optional[Literal["hcat", "vcat"]] = "hcat", | |
| enable_tqdm: Optional[bool] = True, | |
| output_name: Optional[str] = None, | |
| video_quality: Literal["low", "medium", "high"] = "high", | |
| ) -> None: | |
| gs_world = prediction.gaussians | |
| # if target poses are not provided, render the (smooth/interpolate) input poses | |
| if extrinsics is not None: | |
| tgt_extrs = extrinsics | |
| else: | |
| tgt_extrs = torch.from_numpy(prediction.extrinsics).unsqueeze(0).to(gs_world.means) | |
| if prediction.is_metric: | |
| scale_factor = prediction.scale_factor | |
| if scale_factor is not None: | |
| tgt_extrs[:, :, :3, 3] /= scale_factor | |
| tgt_intrs = ( | |
| intrinsics | |
| if intrinsics is not None | |
| else torch.from_numpy(prediction.intrinsics).unsqueeze(0).to(gs_world.means) | |
| ) | |
| # if render resolution is not provided, render the input ones | |
| if out_image_hw is not None: | |
| H, W = out_image_hw | |
| else: | |
| H, W = prediction.depth.shape[-2:] | |
| # if single views, render wander trj | |
| if tgt_extrs.shape[1] <= 1: | |
| trj_mode = "wander" | |
| # trj_mode = "dolly_zoom" | |
| color, depth = run_renderer_in_chunk_w_trj_mode( | |
| gaussians=gs_world, | |
| extrinsics=tgt_extrs, | |
| intrinsics=tgt_intrs, | |
| image_shape=(H, W), | |
| chunk_size=chunk_size, | |
| trj_mode=trj_mode, | |
| use_sh=True, | |
| color_mode=color_mode, | |
| enable_tqdm=enable_tqdm, | |
| ) | |
| # save as video | |
| ffmpeg_params = [ | |
| "-crf", | |
| VIDEO_QUALITY_MAP[video_quality]["crf"], | |
| "-preset", | |
| VIDEO_QUALITY_MAP[video_quality]["preset"], | |
| "-pix_fmt", | |
| "yuv420p", | |
| ] # best compatibility | |
| os.makedirs(os.path.join(export_dir, "gs_video"), exist_ok=True) | |
| for idx in range(color.shape[0]): | |
| video_i = color[idx] | |
| if vis_depth is not None: | |
| depth_i = vis_depth_map_tensor(depth[0]) | |
| cat_fn = hcat if vis_depth == "hcat" else vcat | |
| video_i = torch.stack([cat_fn(c, d) for c, d in zip(video_i, depth_i)]) | |
| frames = list( | |
| (video_i.clamp(0, 1) * 255).byte().permute(0, 2, 3, 1).cpu().numpy() | |
| ) # T x H x W x C, uint8, numpy() | |
| fps = 24 | |
| clip = mpy.ImageSequenceClip(frames, fps=fps) | |
| output_name = f"{idx:04d}_{trj_mode}" if output_name is None else output_name | |
| save_path = os.path.join(export_dir, f"gs_video/{output_name}.mp4") | |
| # clip.write_videofile(save_path, codec="libx264", audio=False, bitrate="4000k") | |
| clip.write_videofile( | |
| save_path, | |
| codec="libx264", | |
| audio=False, | |
| fps=fps, | |
| ffmpeg_params=ffmpeg_params, | |
| ) | |
| return | |