Spaces:
Runtime error
Runtime error
| import argparse | |
| import math | |
| import os | |
| import time | |
| import imageio | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import tqdm | |
| import viser | |
| from pathlib import Path | |
| from gsplat._helper import load_test_data | |
| from gsplat.distributed import cli | |
| from gsplat.rendering import rasterization | |
| from nerfview import CameraState, RenderTabState, apply_float_colormap | |
| from examples.gsplat_viewer import GsplatViewer, GsplatRenderTabState | |
# Maps the viewer's render-mode names to gsplat ``rasterization`` render modes.
# "alpha" rasterizes RGB and then displays the accumulated alpha instead.
_RENDER_MODE_MAP = {
    "rgb": "RGB",
    "depth(accumulated)": "D",
    "depth(expected)": "ED",
    "alpha": "RGB",
}


def _render_test_scene(device, args, world_rank: int, world_size: int):
    """Render gsplat's bundled test scene once and save the result as a PNG.

    The scene is sharded across ranks: rank ``r`` keeps every
    ``world_size``-th Gaussian starting at index ``r``. It is rendered with
    ``render_mode="RGB+D"`` from the first camera of this rank's strided
    camera subset. The sum of the render is back-propagated once. The result
    is written to ``{args.output_dir}/render_rank{world_rank}.png`` as
    RGB | depth | alpha tiles placed side by side.

    Args:
        device: Torch device the test data is loaded onto.
        args: Parsed CLI arguments; ``scene_grid`` and ``output_dir`` are read.
        world_rank: Rank of this process.
        world_size: Total number of ranks (at most 2).

    Returns:
        ``(means, quats, scales, opacities, colors, sh_degree)`` for this
        rank's shard. ``sh_degree`` is ``None``.

    Raises:
        ValueError: if ``world_size`` is greater than 2.
    """
    if world_size > 2:
        raise ValueError(f"the test scene supports at most 2 ranks, got {world_size}")
    (
        means,
        quats,
        scales,
        opacities,
        colors,
        viewmats,
        Ks,
        width,
        height,
    ) = load_test_data(device=device, scene_grid=args.scene_grid)

    def shard(t):
        # Every `world_size`-th entry, starting at this rank's index.
        return t[world_rank::world_size].contiguous()

    means, quats, scales, opacities, colors = (
        shard(t) for t in (means, quats, scales, opacities, colors)
    )
    # Leaf tensors need gradients so the backward pass below is exercised.
    for t in (means, quats, scales, opacities, colors):
        t.requires_grad = True
    viewmats = shard(viewmats)[:1].contiguous()
    Ks = shard(Ks)[:1].contiguous()
    print(
        "rank", world_rank,
        "Number of Gaussians:", len(means),
        "Number of Cameras:", len(viewmats),
    )

    # batched render (single iteration; tqdm reports its wall time)
    for _ in tqdm.trange(1):
        render_colors, render_alphas, _meta = rasterization(
            means,  # [N, 3]
            quats,  # [N, 4]
            scales,  # [N, 3]
            opacities,  # [N]
            colors,  # [N, S, 3]
            viewmats,  # [C, 4, 4]
            Ks,  # [C, 3, 3]
            width,
            height,
            render_mode="RGB+D",
            packed=False,
            distributed=world_size > 1,
        )
    C = render_colors.shape[0]
    assert render_colors.shape == (C, height, width, 4)
    assert render_alphas.shape == (C, height, width, 1)
    render_colors.sum().backward()

    render_rgbs = render_colors[..., 0:3]
    render_depths = render_colors[..., 3:4]
    # Normalize depth for display; the floor avoids 0/0 when nothing is hit.
    render_depths = render_depths / render_depths.max().clamp(min=1e-10)

    # dump batch images: columns are RGB | depth | alpha, cameras stacked vertically
    os.makedirs(args.output_dir, exist_ok=True)
    canvas = (
        torch.cat(
            [
                render_rgbs.reshape(C * height, width, 3),
                render_depths.reshape(C * height, width, 1).expand(-1, -1, 3),
                render_alphas.reshape(C * height, width, 1).expand(-1, -1, 3),
            ],
            dim=1,
        )
        .detach()
        .cpu()
        .numpy()
    )
    # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
    imageio.imsave(
        f"{args.output_dir}/render_rank{world_rank}.png",
        np.clip(canvas * 255, 0, 255).astype(np.uint8),
    )
    return means, quats, scales, opacities, colors, None


def _load_checkpoints(ckpt_paths, device):
    """Load and concatenate the Gaussians stored in one or more checkpoints.

    Each file must contain a ``"splats"`` mapping with raw ``means``,
    ``quats``, ``scales``, ``opacities``, ``sh0`` and ``shN`` tensors. The raw
    values are mapped as follows:

    * quaternions are L2-normalized;
    * ``exp`` is applied to scales;
    * ``sigmoid`` is applied to opacities.

    Each result is then concatenated along dim 0.

    Returns:
        ``(means, quats, scales, opacities, colors, sh_degree)``. ``colors``
        is ``sh0`` and ``shN`` concatenated along dim ``-2``, and
        ``sh_degree`` is derived from that coefficient count.
    """
    means, quats, scales, opacities, sh0, shN = [], [], [], [], [], []
    for ckpt_path in ckpt_paths:
        # NOTE(review): torch.load may unpickle arbitrary objects depending on
        # the installed torch's `weights_only` default; only load trusted files.
        splats = torch.load(ckpt_path, map_location=device)["splats"]
        means.append(splats["means"])
        quats.append(F.normalize(splats["quats"], p=2, dim=-1))
        scales.append(torch.exp(splats["scales"]))
        opacities.append(torch.sigmoid(splats["opacities"]))
        sh0.append(splats["sh0"])
        shN.append(splats["shN"])
    means = torch.cat(means, dim=0)
    quats = torch.cat(quats, dim=0)
    scales = torch.cat(scales, dim=0)
    opacities = torch.cat(opacities, dim=0)
    colors = torch.cat([torch.cat(sh0, dim=0), torch.cat(shN, dim=0)], dim=-2)
    # The coefficient count is (sh_degree + 1) ** 2.
    sh_degree = math.isqrt(colors.shape[-2]) - 1
    print("Number of Gaussians:", len(means))
    return means, quats, scales, opacities, colors, sh_degree


def _make_viewer_render_fn(means, quats, scales, opacities, colors, sh_degree, device):
    """Build the per-frame render callback handed to ``GsplatViewer``.

    The callback renders the captured Gaussians from the viewer's current
    camera and returns a numpy image for the selected render mode
    (RGB, colormapped depth, or colormapped alpha).
    """

    # Rendering is inference only. no_grad also prevents `.numpy()` from failing
    # when the Gaussians require grad (as they do for the test scene).
    @torch.no_grad()
    def viewer_render_fn(camera_state: CameraState, render_tab_state: RenderTabState):
        assert isinstance(render_tab_state, GsplatRenderTabState)
        if render_tab_state.preview_render:
            width = render_tab_state.render_width
            height = render_tab_state.render_height
        else:
            width = render_tab_state.viewer_width
            height = render_tab_state.viewer_height
        c2w = torch.from_numpy(camera_state.c2w).float().to(device)
        K = torch.from_numpy(camera_state.get_K((width, height))).float().to(device)
        viewmat = c2w.inverse()  # camera-to-world -> world-to-camera

        render_colors, render_alphas, info = rasterization(
            means,  # [N, 3]
            quats,  # [N, 4]
            scales,  # [N, 3]
            opacities,  # [N]
            colors,  # [N, S, 3]
            viewmat[None],  # [1, 4, 4]
            K[None],  # [1, 3, 3]
            width,
            height,
            sh_degree=(
                min(render_tab_state.max_sh_degree, sh_degree)
                if sh_degree is not None
                else None
            ),
            near_plane=render_tab_state.near_plane,
            far_plane=render_tab_state.far_plane,
            radius_clip=render_tab_state.radius_clip,
            eps2d=render_tab_state.eps2d,
            backgrounds=torch.tensor([render_tab_state.backgrounds], device=device)
            / 255.0,
            render_mode=_RENDER_MODE_MAP[render_tab_state.render_mode],
            rasterize_mode=render_tab_state.rasterize_mode,
            camera_model=render_tab_state.camera_model,
        )
        render_tab_state.total_gs_count = len(means)
        render_tab_state.rendered_gs_count = (info["radii"] > 0).all(-1).sum().item()

        if render_tab_state.render_mode == "rgb":
            # colors represented with SH are not guaranteed to be in [0, 1]
            render_colors = render_colors[0, ..., 0:3].clamp(0, 1)
            renders = render_colors.cpu().numpy()
        elif render_tab_state.render_mode in ["depth(accumulated)", "depth(expected)"]:
            # normalize depth to [0, 1]
            depth = render_colors[0, ..., 0:1]
            if render_tab_state.normalize_nearfar:
                near_plane = render_tab_state.near_plane
                far_plane = render_tab_state.far_plane
            else:
                near_plane = depth.min()
                far_plane = depth.max()
            depth_norm = (depth - near_plane) / (far_plane - near_plane + 1e-10)
            depth_norm = torch.clip(depth_norm, 0, 1)
            if render_tab_state.inverse:
                depth_norm = 1 - depth_norm
            renders = (
                apply_float_colormap(depth_norm, render_tab_state.colormap)
                .cpu()
                .numpy()
            )
        elif render_tab_state.render_mode == "alpha":
            alpha = render_alphas[0, ..., 0:1]
            if render_tab_state.inverse:
                alpha = 1 - alpha
            renders = (
                apply_float_colormap(alpha, render_tab_state.colormap).cpu().numpy()
            )
        return renders

    return viewer_render_fn


def main(local_rank: int, world_rank: int, world_size: int, args):
    """Load or render a scene, then serve it in an interactive viewer.

    This is the function passed to ``gsplat.distributed.cli``.

    * Without ``--ckpt``: the bundled test scene is sharded across ranks,
      rendered once and dumped to ``args.output_dir``.
    * With ``--ckpt``: the given checkpoints are loaded and concatenated.

    In both cases a ``GsplatViewer`` is then served on ``args.port`` until the
    process is interrupted.

    Args:
        local_rank: CUDA device index used by this process.
        world_rank: Global rank of this process.
        world_size: Total number of processes.
        args: Parsed CLI arguments (``ckpt``, ``scene_grid``, ``output_dir``,
            ``port``).
    """
    torch.manual_seed(42)
    device = torch.device("cuda", local_rank)

    if args.ckpt is None:
        scene = _render_test_scene(device, args, world_rank, world_size)
    else:
        scene = _load_checkpoints(args.ckpt, device)

    # register and open viewer
    server = viser.ViserServer(port=args.port, verbose=False)
    _ = GsplatViewer(
        server=server,
        render_fn=_make_viewer_render_fn(*scene, device),
        output_dir=Path(args.output_dir),
        mode="rendering",
    )
    print("Viewer running... Ctrl+C to exit.")
    # Keep the process alive indefinitely (a single long sleep would exit
    # after a fixed time) so the viewer stays up until Ctrl+C.
    while True:
        time.sleep(3600)
if __name__ == "__main__":
    # Usage examples:
    #
    # Use a single GPU to view a trained checkpoint:
    #   CUDA_VISIBLE_DEVICES=9 python -m simple_viewer \
    #       --ckpt results/garden/ckpts/ckpt_6999_rank0.pt \
    #       --output_dir results/garden/ \
    #       --port 8082
    #
    # Without --ckpt, the bundled test scene is rendered and viewed instead:
    #   CUDA_VISIBLE_DEVICES=9 python -m simple_viewer \
    #       --output_dir results/garden/ \
    #       --port 8082
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_dir", type=str, default="results/", help="where to dump outputs"
    )
    parser.add_argument(
        "--scene_grid", type=int, default=1, help="repeat the scene into a grid of NxN"
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        nargs="+",
        default=None,
        help="path(s) to one or more .pt checkpoint files",
    )
    parser.add_argument(
        "--port", type=int, default=8080, help="port for the viewer server"
    )
    args = parser.parse_args()
    # Validate via argparse (prints usage, exits with status 2) rather than
    # `assert`, which is stripped when Python runs with -O.
    if args.scene_grid % 2 != 1:
        parser.error("scene_grid must be odd")
    cli(main, args, verbose=True)