| | import argparse |
| | import math |
| | import os |
| | import time |
| |
|
| | import imageio |
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | import tqdm |
| | import viser |
| | from pathlib import Path |
| | from gsplat._helper import load_test_data |
| | from gsplat.distributed import cli |
| | from gsplat.rendering import rasterization |
| |
|
| | from nerfview import CameraState, RenderTabState, apply_float_colormap |
| | from examples.gsplat_viewer import GsplatViewer, GsplatRenderTabState |
| |
|
| |
|
# Viewer render-mode label -> gsplat ``rasterization`` ``render_mode`` string.
# "alpha" rasterizes in RGB mode because the alpha map is returned separately
# by ``rasterization``. Built once here instead of on every viewer frame.
_RENDER_MODE_MAP = {
    "rgb": "RGB",
    "depth(accumulated)": "D",
    "depth(expected)": "ED",
    "alpha": "RGB",
}


def _render_test_scene(
    means,
    quats,
    scales,
    opacities,
    colors,
    viewmats,
    Ks,
    width: int,
    height: int,
    world_rank: int,
    world_size: int,
    output_dir: str,
) -> None:
    """Render the test cameras once, backprop once, and dump a PNG.

    The image has one row per rendered camera and three columns: RGB,
    depth normalized by its maximum, and alpha. It is written to
    ``{output_dir}/render_rank{world_rank}.png``.
    """
    for _ in tqdm.trange(1):
        render_colors, render_alphas, _meta = rasterization(
            means,
            quats,
            scales,
            opacities,
            colors,
            viewmats,
            Ks,
            width,
            height,
            render_mode="RGB+D",
            packed=False,
            distributed=world_size > 1,
        )
        # The number of rendered cameras is read back from the output rather
        # than from ``viewmats`` (the distributed path may change it).
        C = render_colors.shape[0]
        assert render_colors.shape == (C, height, width, 4)
        assert render_alphas.shape == (C, height, width, 1)
        # One backward pass to exercise the gradient path.
        render_colors.sum().backward()

    render_rgbs = render_colors[..., 0:3]
    render_depths = render_colors[..., 3:4]
    # Clamp the denominator: an all-zero depth map (nothing rendered) would
    # otherwise turn the whole depth column into NaNs.
    render_depths = render_depths / render_depths.max().clamp(min=1e-10)

    os.makedirs(output_dir, exist_ok=True)
    canvas = (
        torch.cat(
            [
                render_rgbs.reshape(C * height, width, 3),
                render_depths.reshape(C * height, width, 1).expand(-1, -1, 3),
                render_alphas.reshape(C * height, width, 1).expand(-1, -1, 3),
            ],
            dim=1,
        )
        .detach()
        .cpu()
        .numpy()
    )
    # Clip before quantizing: values outside [0, 1] would otherwise wrap
    # around (or be undefined) when cast to uint8.
    imageio.imsave(
        f"{output_dir}/render_rank{world_rank}.png",
        (np.clip(canvas, 0.0, 1.0) * 255).astype(np.uint8),
    )


def _load_checkpoints(ckpt_paths, device):
    """Load and concatenate the Gaussians stored in one or more checkpoints.

    Each file must contain a ``"splats"`` dict with the keys ``means``,
    ``quats``, ``scales``, ``opacities``, ``sh0`` and ``shN``. Activations are
    applied here: quaternions are L2-normalized, scales go through ``exp`` and
    opacities through ``sigmoid``.

    Returns:
        ``(means, quats, scales, opacities, colors, sh_degree)``, where
        ``colors`` is ``sh0`` and ``shN`` concatenated along dim -2.

    Raises:
        ValueError: if the number of SH coefficients along dim -2 is not a
            perfect square, so no SH degree matches it.
    """
    means, quats, scales, opacities, sh0, shN = [], [], [], [], [], []
    for ckpt_path in ckpt_paths:
        splats = torch.load(ckpt_path, map_location=device)["splats"]
        means.append(splats["means"])
        quats.append(F.normalize(splats["quats"], p=2, dim=-1))
        scales.append(torch.exp(splats["scales"]))
        opacities.append(torch.sigmoid(splats["opacities"]))
        sh0.append(splats["sh0"])
        shN.append(splats["shN"])

    colors = torch.cat([torch.cat(sh0, dim=0), torch.cat(shN, dim=0)], dim=-2)
    num_coeffs = colors.shape[-2]
    # (degree + 1) ** 2 coefficients; integer sqrt avoids float rounding.
    sh_degree = math.isqrt(num_coeffs) - 1
    if (sh_degree + 1) ** 2 != num_coeffs:
        raise ValueError(
            f"expected a square number of SH coefficients, got {num_coeffs}"
        )
    return (
        torch.cat(means, dim=0),
        torch.cat(quats, dim=0),
        torch.cat(scales, dim=0),
        torch.cat(opacities, dim=0),
        colors,
        sh_degree,
    )


def _make_viewer_render_fn(means, quats, scales, opacities, colors, sh_degree, device):
    """Build the per-frame render callback handed to ``GsplatViewer``.

    ``sh_degree`` is ``None`` when ``colors`` are plain colors rather than
    SH coefficients (the test-data path).
    """

    @torch.no_grad()
    def viewer_render_fn(camera_state: CameraState, render_tab_state: RenderTabState):
        """Render one viewer frame and return it as a NumPy image.

        ``render_tab_state`` must be a ``GsplatRenderTabState``. Its fields
        choose the resolution, render mode, clipping planes, background and
        colormap. The total and rendered Gaussian counts are written back to it.
        """
        assert isinstance(render_tab_state, GsplatRenderTabState)
        if render_tab_state.preview_render:
            width = render_tab_state.render_width
            height = render_tab_state.render_height
        else:
            width = render_tab_state.viewer_width
            height = render_tab_state.viewer_height
        c2w = torch.from_numpy(camera_state.c2w).float().to(device)
        K = torch.from_numpy(camera_state.get_K((width, height))).float().to(device)
        # Inverse of camera-to-world, i.e. the world-to-camera view matrix.
        viewmat = c2w.inverse()

        render_colors, render_alphas, info = rasterization(
            means,
            quats,
            scales,
            opacities,
            colors,
            viewmat[None],
            K[None],
            width,
            height,
            sh_degree=(
                min(render_tab_state.max_sh_degree, sh_degree)
                if sh_degree is not None
                else None
            ),
            near_plane=render_tab_state.near_plane,
            far_plane=render_tab_state.far_plane,
            radius_clip=render_tab_state.radius_clip,
            eps2d=render_tab_state.eps2d,
            backgrounds=torch.tensor([render_tab_state.backgrounds], device=device)
            / 255.0,
            render_mode=_RENDER_MODE_MAP[render_tab_state.render_mode],
            rasterize_mode=render_tab_state.rasterize_mode,
            camera_model=render_tab_state.camera_model,
        )
        render_tab_state.total_gs_count = len(means)
        # A Gaussian counts as rendered when every entry of its radius along
        # the last dim is positive.
        render_tab_state.rendered_gs_count = (info["radii"] > 0).all(-1).sum().item()

        if render_tab_state.render_mode == "rgb":
            renders = render_colors[0, ..., 0:3].clamp(0, 1).cpu().numpy()
        elif render_tab_state.render_mode in ["depth(accumulated)", "depth(expected)"]:
            depth = render_colors[0, ..., 0:1]
            if render_tab_state.normalize_nearfar:
                near_plane = render_tab_state.near_plane
                far_plane = render_tab_state.far_plane
            else:
                near_plane = depth.min()
                far_plane = depth.max()
            depth_norm = (depth - near_plane) / (far_plane - near_plane + 1e-10)
            depth_norm = torch.clip(depth_norm, 0, 1)
            if render_tab_state.inverse:
                depth_norm = 1 - depth_norm
            renders = (
                apply_float_colormap(depth_norm, render_tab_state.colormap)
                .cpu()
                .numpy()
            )
        elif render_tab_state.render_mode == "alpha":
            alpha = render_alphas[0, ..., 0:1]
            if render_tab_state.inverse:
                alpha = 1 - alpha
            renders = (
                apply_float_colormap(alpha, render_tab_state.colormap).cpu().numpy()
            )
        return renders

    return viewer_render_fn


def main(local_rank: int, world_rank, world_size: int, args):
    """Load a scene (test data or checkpoints) and serve it in the viewer.

    Args:
        local_rank: index of the CUDA device used by this process.
        world_rank: rank of this process; selects its shard of the test data.
        world_size: number of processes; the test-data path supports at most 2.
        args: parsed CLI namespace providing ``ckpt``, ``scene_grid``,
            ``output_dir`` and ``port``.

    When ``args.ckpt`` is None, the built-in test scene is sharded across
    ranks, rendered once, and the result is saved as a PNG. Otherwise the
    given checkpoints are loaded and concatenated. Either way, a viewer server
    is then started and this call blocks until the process is interrupted.
    """
    torch.manual_seed(42)
    device = torch.device("cuda", local_rank)

    if args.ckpt is None:
        (
            means,
            quats,
            scales,
            opacities,
            colors,
            viewmats,
            Ks,
            width,
            height,
        ) = load_test_data(device=device, scene_grid=args.scene_grid)

        # Round-robin shard of the Gaussians across ranks, made trainable so
        # that the backward pass in _render_test_scene has leaves to reach.
        assert world_size <= 2
        means, quats, scales, opacities, colors = (
            t[world_rank::world_size].contiguous().requires_grad_()
            for t in (means, quats, scales, opacities, colors)
        )
        # Each rank keeps only the first camera of its own camera shard.
        viewmats = viewmats[world_rank::world_size][:1].contiguous()
        Ks = Ks[world_rank::world_size][:1].contiguous()

        sh_degree = None
        print(
            "rank", world_rank,
            "Number of Gaussians:", len(means),
            "Number of Cameras:", len(viewmats),
        )
        _render_test_scene(
            means,
            quats,
            scales,
            opacities,
            colors,
            viewmats,
            Ks,
            width,
            height,
            world_rank,
            world_size,
            args.output_dir,
        )
    else:
        means, quats, scales, opacities, colors, sh_degree = _load_checkpoints(
            args.ckpt, device
        )
        print("Number of Gaussians:", len(means))

    # NOTE(review): on the test-data path the viewer shows only this rank's
    # shard of the Gaussians (rendering there is not distributed).
    server = viser.ViserServer(port=args.port, verbose=False)
    _ = GsplatViewer(
        server=server,
        render_fn=_make_viewer_render_fn(
            means, quats, scales, opacities, colors, sh_degree, device
        ),
        output_dir=Path(args.output_dir),
        mode="rendering",
    )
    print("Viewer running... Ctrl+C to exit.")
    # Block until interrupted, as the message promises. A single fixed sleep
    # would silently shut the viewer down after ~28 hours.
    while True:
        time.sleep(3600)
| |
|
| |
|
if __name__ == "__main__":
    # Usage examples:
    #
    # View trained checkpoint(s) on a single GPU (several .pt paths may be
    # passed to --ckpt; their Gaussians are concatenated):
    #   CUDA_VISIBLE_DEVICES=9 python -m simple_viewer \
    #       --ckpt results/garden/ckpts/ckpt_6999_rank0.pt \
    #       --output_dir results/garden/ \
    #       --port 8082
    #
    # Without --ckpt the built-in test scene is rendered and viewed instead:
    #   CUDA_VISIBLE_DEVICES=9 python -m simple_viewer \
    #       --output_dir results/garden/ \
    #       --port 8082
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_dir", type=str, default="results/", help="where to dump outputs"
    )
    parser.add_argument(
        "--scene_grid", type=int, default=1, help="repeat the scene into a grid of NxN"
    )
    parser.add_argument(
        "--ckpt", type=str, nargs="+", default=None, help="path to the .pt file"
    )
    parser.add_argument(
        "--port", type=int, default=8080, help="port for the viewer server"
    )
    args = parser.parse_args()
    # Validate with parser.error instead of assert. Asserts are stripped under
    # ``python -O``, and parser.error prints usage and exits with status 2.
    # The positivity check matters because negative odd numbers satisfy
    # ``% 2 == 1`` in Python.
    if args.scene_grid < 1 or args.scene_grid % 2 != 1:
        parser.error("--scene_grid must be a positive odd integer")

    cli(main, args, verbose=True)
| |
|