| | import json |
| | import math |
| | import os |
| | import time |
| | from collections import defaultdict |
| | from dataclasses import dataclass, field |
| | from pathlib import Path |
| | from typing import Dict, List, Optional, Tuple, Union |
| |
|
| | import imageio |
| | import matplotlib |
| | import torchvision |
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | import tqdm |
| | import tyro |
| | import viser |
| | import yaml |
| | import torchvision |
| | import sys |
| | from plyfile import PlyData, PlyElement |
| |
|
| | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) |
| |
|
| | from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri |
| | from src.model.types import Gaussians |
| | from src.post_opt.datasets.colmap import Dataset, Parser |
| | from src.post_opt.datasets.traj import ( |
| | generate_ellipse_path_z, |
| | generate_interpolated_path, |
| | generate_spiral_path, |
| | ) |
| | from fused_ssim import fused_ssim |
| |
|
| | from src.utils.image import process_image |
| | from src.post_opt.exporter import export_splats |
| | from src.post_opt.lib_bilagrid import BilateralGrid, color_correct, slice, total_variation_loss |
| | from torch import Tensor |
| | from torch.nn.parallel import DistributedDataParallel as DDP |
| | from torch.utils.tensorboard import SummaryWriter |
| | from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure |
| | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity |
| | from typing_extensions import Literal, assert_never |
| | from src.post_opt.utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed |
| |
|
| | |
| | from gsplat.compression import PngCompression |
| | from gsplat.distributed import cli |
| | |
| | |
| | from gsplat import rasterization |
| | from gsplat.strategy import DefaultStrategy, MCMCStrategy |
| | from src.post_opt.gsplat_viewer import GsplatViewer, GsplatRenderTabState |
| | from nerfview import CameraState, RenderTabState, apply_float_colormap |
| |
|
| | import torch |
| | from einops import rearrange |
| | from jaxtyping import Float |
| | from torch import Tensor |
| | from scipy.spatial.transform import Rotation as R |
| |
|
| | from src.model.model.anysplat import AnySplat |
| |
|
| |
|
| | |
def quaternion_to_matrix(
    quaternions: Float[Tensor, "*batch 4"],
    eps: float = 1e-8,
) -> Float[Tensor, "*batch 3 3"]:
    """Convert quaternions in (x, y, z, w) order to 3x3 rotation matrices.

    `eps` keeps the normalization term finite for near-zero quaternions.
    """
    x, y, z, w = torch.unbind(quaternions, dim=-1)
    # 2 / |q|^2 -- folds the quaternion normalization into the matrix entries.
    s2 = 2.0 / (quaternions.pow(2).sum(dim=-1) + eps)

    flat = torch.stack(
        (
            1 - s2 * (y * y + z * z),
            s2 * (x * y - z * w),
            s2 * (x * z + y * w),
            s2 * (x * y + z * w),
            1 - s2 * (x * x + z * z),
            s2 * (y * z - x * w),
            s2 * (x * z - y * w),
            s2 * (y * z + x * w),
            1 - s2 * (x * x + y * y),
        ),
        dim=-1,
    )
    # Unflatten the last axis of 9 entries into a 3x3 matrix.
    return flat.reshape(*flat.shape[:-1], 3, 3)
| |
|
def construct_list_of_attributes(num_rest: int) -> list[str]:
    """Return the PLY vertex property names in gaussian-splatting order.

    Order: position, (zero) normals, DC SH, `num_rest` rest-SH coefficients,
    opacity, 3 log-scales, 4 quaternion components.
    """
    names = ["x", "y", "z", "nx", "ny", "nz"]
    names += [f"f_dc_{c}" for c in range(3)]
    names += [f"f_rest_{c}" for c in range(num_rest)]
    names.append("opacity")
    names += [f"scale_{c}" for c in range(3)]
    names += [f"rot_{c}" for c in range(4)]
    return names
| |
|
def export_ply(
    means: Float[Tensor, "gaussian 3"],
    scales: Float[Tensor, "gaussian 3"],
    rotations: Float[Tensor, "gaussian 4"],
    harmonics: Float[Tensor, "gaussian 3 d_sh"],
    opacities: Float[Tensor, " gaussian"],
    path: Path,
    shift_and_scale: bool = False,
    save_sh_dc_only: bool = True,
):
    """Write gaussians to a .ply file in the standard 3DGS vertex layout.

    Args:
        means: gaussian centers.
        scales: per-axis scales (written as-is; caller decides log/linear).
        rotations: quaternions in (x, y, z, w) order.
        harmonics: SH coefficients, channel-major (3, d_sh) per gaussian.
        opacities: per-gaussian opacity values (written as-is).
        path: output file path; parent directories are created.
        shift_and_scale: if True, recenter on the median and normalize by the
            95th-percentile absolute coordinate before export.
        save_sh_dc_only: if True, drop the rest-SH coefficients and keep only
            the DC term.
    """
    if shift_and_scale:
        # Recenter on the per-axis median, then normalize so the scene fits
        # roughly in a unit box (95th percentile of |coord| across axes).
        means = means - means.median(dim=0).values

        scale_factor = means.abs().quantile(0.95, dim=0).max()
        means = means / scale_factor
        scales = scales / scale_factor

    # Round-trip through a rotation matrix to canonicalize/normalize the
    # quaternions, then reorder from scipy's (x, y, z, w) to the (w, x, y, z)
    # convention used by the 3DGS .ply format.
    rotations = R.from_quat(rotations.detach().cpu().numpy()).as_matrix()
    rotations = R.from_matrix(rotations).as_quat()
    x, y, z, w = rearrange(rotations, "g xyzw -> xyzw g")
    rotations = np.stack((w, x, y, z), axis=-1)

    # Split DC (first SH coefficient) from the remaining coefficients.
    f_dc = harmonics[..., 0]
    f_rest = harmonics[..., 1:].flatten(start_dim=1)

    # Field order here must match the attribute list order below exactly.
    dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(0 if save_sh_dc_only else f_rest.shape[1])]
    elements = np.empty(means.shape[0], dtype=dtype_full)
    attributes = [
        means.detach().cpu().numpy(),
        # Normals are unused by 3DGS viewers but expected in the layout.
        torch.zeros_like(means).detach().cpu().numpy(),
        f_dc.detach().cpu().contiguous().numpy(),
        f_rest.detach().cpu().contiguous().numpy(),
        opacities[..., None].detach().cpu().numpy(),
        scales.detach().cpu().numpy(),
        rotations,
    ]
    if save_sh_dc_only:
        # Drop f_rest (index 3) to match the DC-only dtype built above.
        attributes.pop(3)

    attributes = np.concatenate(attributes, axis=1)
    elements[:] = list(map(tuple, attributes))
    path.parent.mkdir(exist_ok=True, parents=True)
    PlyData([PlyElement.describe(elements, "vertex")]).write(path)
| |
|
def colorize_depth_maps(depth_map, min_depth=0.0, max_depth=1.0, cmap="Spectral", valid_mask=None):
    """Colorize depth maps with a matplotlib colormap.

    Args:
        depth_map: (H, W) or (B, H, W) torch tensor or numpy array; extra
            singleton dimensions are squeezed away.
        min_depth, max_depth: unused; kept for interface compatibility — the
            map is min-max normalized from the data itself.
        cmap: matplotlib colormap name.
        valid_mask: optional boolean mask; pixels outside it are zeroed.

    Returns:
        (B, 3, H, W) colored image with values in [0, 1]; a torch tensor when
        the input was a tensor, otherwise a numpy array.

    Raises:
        TypeError: if `depth_map` is neither a tensor nor an ndarray.
    """
    assert len(depth_map.shape) >= 2, "Invalid dimension"

    is_tensor = isinstance(depth_map, torch.Tensor)
    if is_tensor:
        # .cpu() so CUDA tensors convert cleanly (original crashed on them).
        depth = depth_map.detach().clone().squeeze().cpu().numpy()
    elif isinstance(depth_map, np.ndarray):
        depth = depth_map.copy().squeeze()
    else:
        # Fail fast instead of hitting a NameError on `depth` further down.
        raise TypeError(f"Unsupported depth_map type: {type(depth_map)}")

    # Ensure a leading batch axis: (H, W) -> (1, H, W).
    if depth.ndim < 3:
        depth = depth[np.newaxis, :, :]

    cm = matplotlib.colormaps[cmap]
    # Global min-max normalization across the whole batch.
    # NOTE(review): a constant depth map divides 0/0 here, as in the original.
    depth = ((depth - depth.min()) / (depth.max() - depth.min())).clip(0, 1)
    img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3]  # drop alpha channel
    img_colored_np = np.rollaxis(img_colored_np, 3, 1)  # (B, H, W, 3) -> (B, 3, H, W)

    if valid_mask is not None:
        # Check the mask's own type (the original checked depth_map's type,
        # which broke for tensor depth + ndarray mask and vice versa).
        if isinstance(valid_mask, torch.Tensor):
            valid_mask = valid_mask.detach().cpu().numpy()
        valid_mask = valid_mask.squeeze()
        if valid_mask.ndim < 3:
            valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
        else:
            valid_mask = valid_mask[:, np.newaxis, :, :]
        valid_mask = np.repeat(valid_mask, 3, axis=1)
        img_colored_np[~valid_mask] = 0

    if is_tensor:
        return torch.from_numpy(img_colored_np).float()
    return img_colored_np
| |
|
def build_covariance(
    scale: Float[Tensor, "*#batch 3"],
    rotation_xyzw: Float[Tensor, "*#batch 4"],
) -> Float[Tensor, "*batch 3 3"]:
    """Assemble per-gaussian covariances Sigma = R S S^T R^T."""
    S = scale.diag_embed()
    Rm = quaternion_to_matrix(rotation_xyzw)
    return Rm @ S @ S.transpose(-1, -2) @ Rm.transpose(-1, -2)
| |
|
| |
|
@dataclass
class Config:
    """Configuration for gsplat post-optimization training."""

    # Disable the interactive viser viewer.
    disable_viewer: bool = True
    # Paths to checkpoint files; when set, training is presumably skipped in
    # favor of evaluation — TODO confirm against the entry point.
    ckpt: Optional[List[str]] = None
    # Compression strategy for exported splats ("png" is the only option).
    compression: Optional[Literal["png"]] = None
    # Render-trajectory type for video rendering (e.g. "interp").
    render_traj_path: str = "interp"

    # Directory containing the input images.
    data_dir: str = "data/360_v2/garden"
    # Downsample factor for the dataset (unused in this chunk — verify).
    data_factor: int = 4
    # Directory where results (ckpts/stats/renders/ply/tb) are written.
    result_dir: str = "results/garden"
    # Every N-th image becomes a validation/target view.
    test_every: int = 8
    # Random crop size for training (None = full image).
    patch_size: Optional[int] = None
    # Global scene-scale multiplier (unused in this chunk — verify).
    global_scale: float = 1.0
    # Normalize the world space (unused in this chunk — verify).
    normalize_world_space: bool = True
    # Camera projection model used for rasterization.
    camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole"

    # Port for the viser viewer server.
    port: int = 8080

    # Per-GPU batch size; learning rates scale with its square root.
    batch_size: int = 1
    # Global factor applied to all step schedules via adjust_steps().
    steps_scaler: float = 1.0

    # Total number of training steps.
    max_steps: int = 3_000
    # Steps (1-indexed) at which to run evaluation.
    eval_steps: List[int] = field(default_factory=lambda: [1, 1_000, 3_000, 7_000, 10_000])
    # Steps (1-indexed) at which to save checkpoints.
    save_steps: List[int] = field(default_factory=lambda: [1, 1_000, 3_000, 7_000, 10_000])
    # Whether to export .ply point clouds during training.
    save_ply: bool = False
    # Steps (1-indexed) at which to export .ply files.
    ply_steps: List[int] = field(default_factory=lambda: [1, 1_000, 3_000, 7_000, 10_000])
    # Skip trajectory-video rendering.
    disable_video: bool = False

    # Initialization source ("sfm" or random) — initialization here actually
    # comes from the AnySplat encoder; these init_* fields are passed through
    # but unused in create_splats_with_optimizers.
    init_type: str = "sfm"
    init_num_pts: int = 100_000
    init_extent: float = 3.0
    # Maximum spherical-harmonics degree.
    sh_degree: int = 4
    # Interval (in steps) for increasing the active SH degree.
    sh_degree_interval: int = 1000
    # Initial opacity / scale for random init (unused here — see above).
    init_opa: float = 0.1
    init_scale: float = 1.0
    # Weight of the SSIM term in the photometric loss.
    ssim_lambda: float = 0.2

    # Near/far clipping planes for rasterization.
    near_plane: float = 1e-10
    far_plane: float = 1e10

    # Densification strategy (gsplat DefaultStrategy or MCMCStrategy).
    strategy: Union[DefaultStrategy, MCMCStrategy] = field(
        default_factory=DefaultStrategy
    )
    # Use packed-mode rasterization (required for sparse_grad).
    packed: bool = False
    # Use sparse gradients + SparseAdam.
    sparse_grad: bool = False
    # Use a visibility-masked Adam step (SelectiveAdam).
    visible_adam: bool = False
    # Use anti-aliased rasterization.
    antialiased: bool = False

    # Composite renders over a random background each step.
    random_bkgd: bool = False

    # Opacity / scale regularization weights (0 disables).
    opacity_reg: float = 0.0
    scale_reg: float = 0.0

    # Jointly optimize camera poses.
    pose_opt: bool = True
    pose_opt_lr: float = 1e-5
    pose_opt_reg: float = 1e-6
    # Std-dev of synthetic pose noise (0 disables perturbation).
    pose_noise: float = 0.0

    # Per-image appearance optimization (learned embedding + color head).
    app_opt: bool = False
    app_embed_dim: int = 16
    app_opt_lr: float = 1e-3
    app_opt_reg: float = 1e-6

    # Bilateral-grid color correction per training image.
    use_bilateral_grid: bool = False
    # (grid_X, grid_Y, grid_W) resolution of the bilateral grid.
    bilateral_grid_shape: Tuple[int, int, int] = (16, 16, 8)

    # Inverse-depth supervision.
    depth_loss: bool = False
    depth_lambda: float = 1e-2

    # TensorBoard logging interval (steps); 0 disables.
    tb_every: int = 100
    # Also log rendered images to TensorBoard.
    tb_save_image: bool = False

    # Backbone for the LPIPS metric.
    lpips_net: Literal["vgg", "alex"] = "vgg"

    # Per-parameter-group learning rates for the splats.
    lr_means: float = 1.6e-4
    lr_scales: float = 5e-3
    lr_quats: float = 1e-3
    lr_opacities: float = 5e-2
    lr_sh: float = 2.5e-3

    def adjust_steps(self, factor: float):
        """Scale every step schedule by `factor` (typically steps_scaler)."""
        self.eval_steps = [int(i * factor) for i in self.eval_steps]
        self.save_steps = [int(i * factor) for i in self.save_steps]
        self.ply_steps = [int(i * factor) for i in self.ply_steps]
        self.max_steps = int(self.max_steps * factor)
        self.sh_degree_interval = int(self.sh_degree_interval * factor)

        strategy = self.strategy
        if isinstance(strategy, DefaultStrategy):
            # Densification is effectively disabled: refinement never starts
            # before iter 30000 and the stop iter is 0, so the gaussian count
            # stays fixed at its initialization from the encoder.
            strategy.refine_start_iter = 30000
            strategy.refine_stop_iter = 0
            strategy.reset_every = 30000
            strategy.refine_every = 30000

        elif isinstance(strategy, MCMCStrategy):
            strategy.refine_start_iter = int(strategy.refine_start_iter * factor)
            strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
            strategy.refine_every = int(strategy.refine_every * factor)
        else:
            assert_never(strategy)
| |
|
| |
|
def create_splats_with_optimizers(
    gaussians: Gaussians,
    init_num_pts: int = 100_000,
    init_extent: float = 3.0,
    init_opacity: float = 0.1,
    init_scale: float = 1.0,
    sh_degree: int = 3,
    sparse_grad: bool = False,
    visible_adam: bool = False,
    batch_size: int = 1,
    feature_dim: Optional[int] = None,
    device: str = "cuda",
    world_rank: int = 0,
    world_size: int = 1,
    cfg: Config = None,
) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]:
    """Build trainable splat parameters and per-group optimizers.

    Parameters come from the AnySplat encoder output (`gaussians`), not from
    random/SfM init — init_num_pts, init_extent, init_opacity, init_scale,
    sh_degree, feature_dim, and world_rank are accepted for interface
    compatibility but unused in this body.

    Returns:
        (splats, optimizers): ParameterDict with keys means/scales/quats/
        opacities/sh0/shN, and one optimizer per parameter group.
    """
    # Take batch element 0; store scales/opacities in log/logit space so the
    # rasterizer's exp/sigmoid recover the constrained values.
    points = gaussians.means[0].detach().float()
    scales = torch.log(gaussians.scales[0].detach().float())
    quats = gaussians.rotations[0].detach().float()
    opacities = torch.logit(gaussians.opacities[0].detach().float())
    # (N, 3, d_sh) -> (N, d_sh, 3): coefficient-major layout for sh0/shN split.
    harmonics = gaussians.harmonics[0].detach().float().permute(0, 2, 1).contiguous()

    N = points.shape[0]
    scene_scale = 1.0
    # Prune near-transparent gaussians from the initialization.
    masks = opacities.sigmoid() > 0.01
    harmonics = harmonics[masks]
    params = [
        # (name, parameter, learning rate)
        ("means", torch.nn.Parameter(points[masks]), cfg.lr_means * scene_scale),
        ("scales", torch.nn.Parameter(scales[masks]), cfg.lr_scales),
        ("quats", torch.nn.Parameter(quats[masks]), cfg.lr_quats),
        ("opacities", torch.nn.Parameter(opacities[masks]), cfg.lr_opacities),
    ]
    # DC coefficient trains 20x faster than the rest-SH coefficients.
    params.append(("sh0", torch.nn.Parameter(harmonics[:, :1, :]), cfg.lr_sh))
    params.append(("shN", torch.nn.Parameter(harmonics[:, 1:, :]), cfg.lr_sh/20))

    splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device)

    # Scale LR and Adam hyper-parameters with the effective batch size
    # (per-GPU batch size x world size), following gsplat's convention.
    BS = batch_size * world_size
    optimizer_class = None
    if sparse_grad:
        optimizer_class = torch.optim.SparseAdam
    elif visible_adam:
        # NOTE(review): SelectiveAdam is not imported anywhere in this file —
        # presumably gsplat.optimizers.SelectiveAdam; verify before enabling
        # cfg.visible_adam.
        optimizer_class = SelectiveAdam
    else:
        optimizer_class = torch.optim.Adam
    optimizers = {
        name: optimizer_class(
            [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}],
            eps=1e-15 / math.sqrt(BS),
            # Decay betas toward 0 as the effective batch size grows.
            betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)),
        )
        for name, _, lr in params
    }
    return splats, optimizers
| |
|
| |
|
| | class Runner: |
| | """Engine for training and testing.""" |
| |
|
    def __init__(
        self, local_rank: int, world_rank, world_size: int, cfg: Config
    ) -> None:
        """Set up output dirs, run the AnySplat encoder to initialize splats,
        build train/val datasets from predicted poses, and create optimizers,
        metrics, and (optionally) the viewer."""
        set_random_seed(42 + local_rank)

        self.cfg = cfg
        self.world_rank = world_rank
        self.local_rank = local_rank
        self.world_size = world_size
        self.device = f"cuda:{local_rank}"

        # Where to dump results.
        os.makedirs(cfg.result_dir, exist_ok=True)

        # Output sub-directories: checkpoints, JSON stats, renders, ply exports.
        self.ckpt_dir = f"{cfg.result_dir}/ckpts"
        os.makedirs(self.ckpt_dir, exist_ok=True)
        self.stats_dir = f"{cfg.result_dir}/stats"
        os.makedirs(self.stats_dir, exist_ok=True)
        self.render_dir = f"{cfg.result_dir}/renders"
        os.makedirs(self.render_dir, exist_ok=True)
        self.ply_dir = f"{cfg.result_dir}/ply"
        os.makedirs(self.ply_dir, exist_ok=True)

        # TensorBoard writer.
        self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb")

        # Load the pretrained AnySplat model, frozen (it only provides the
        # initial gaussians and camera poses; the splats are trained instead).
        model = AnySplat.from_pretrained("lhjiang/anysplat")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.eval()
        for param in model.parameters():
            param.requires_grad = False

        # Collect images; every cfg.test_every-th image is a target (val)
        # view, the rest are context (train) views.
        image_folder = cfg.data_dir
        image_names = sorted([os.path.join(image_folder, f) for f in os.listdir(image_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
        images = [process_image(img_path) for img_path in image_names]
        ctx_indices = [idx for idx, name in enumerate(image_names) if idx % cfg.test_every != 0]
        tgt_indices = [idx for idx, name in enumerate(image_names) if idx % cfg.test_every == 0]

        ctx_images = torch.stack([images[i] for i in ctx_indices], dim=0).unsqueeze(0).to(device)
        tgt_images = torch.stack([images[i] for i in tgt_indices], dim=0).unsqueeze(0).to(device)
        # Map from [-1, 1] (process_image output, presumably) to [0, 1] —
        # TODO confirm against process_image.
        ctx_images = (ctx_images+1)*0.5
        tgt_images = (tgt_images+1)*0.5
        b, v, _, h, w = tgt_images.shape

        # Run the encoder on context views to get the initial gaussians and
        # predicted context-camera poses.
        encoder_output = model.encoder(
            ctx_images,
            global_step=0,
            visualization_dump={},
        )
        gaussians, pred_context_pose = encoder_output.gaussians, encoder_output.pred_context_pose

        # Run the VGGT aggregator + camera head on context AND target views
        # to recover poses/intrinsics for all cameras.
        num_context_view = ctx_images.shape[1]
        vggt_input_image = torch.cat((ctx_images, tgt_images), dim=1).to(torch.bfloat16)
        # NOTE(review): autocast with enabled=False but dtype=bfloat16 looks
        # contradictory — the input is already cast to bf16 above; verify the
        # intended precision path.
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
            aggregated_tokens_list, patch_start_idx = model.encoder.aggregator(vggt_input_image, intermediate_layer_idx=model.encoder.cfg.intermediate_layer_idx)
            with torch.cuda.amp.autocast(enabled=False):
                # Camera head runs in fp32 on the aggregated tokens.
                fp32_tokens = [token.float() for token in aggregated_tokens_list]
                pred_all_pose_enc = model.encoder.camera_head(fp32_tokens)[-1]
                pred_all_extrinsic, pred_all_intrinsic = pose_encoding_to_extri_intri(pred_all_pose_enc, vggt_input_image.shape[-2:])

        # Pad 3x4 world-to-camera matrices to 4x4 and invert to camera-to-world.
        extrinsic_padding = torch.tensor([0, 0, 0, 1], device=pred_all_extrinsic.device, dtype=pred_all_extrinsic.dtype).view(1, 1, 1, 4).repeat(b, vggt_input_image.shape[1], 1, 1)
        pred_all_extrinsic = torch.cat([pred_all_extrinsic, extrinsic_padding], dim=2).inverse()

        # Normalize intrinsics by image size, then split back into
        # context/target groups.
        pred_all_intrinsic[:, :, 0] = pred_all_intrinsic[:, :, 0] / w
        pred_all_intrinsic[:, :, 1] = pred_all_intrinsic[:, :, 1] / h
        pred_all_context_extrinsic, pred_all_target_extrinsic = pred_all_extrinsic[:, :num_context_view], pred_all_extrinsic[:, num_context_view:]
        pred_all_context_intrinsic, pred_all_target_intrinsic = pred_all_intrinsic[:, :num_context_view], pred_all_intrinsic[:, num_context_view:]

        # Rescale VGGT translations so they live in the same metric space as
        # the encoder's predicted context poses (ratio of mean translations).
        scale_factor = pred_context_pose['extrinsic'][:, :, :3, 3].mean() / pred_all_context_extrinsic[:, :, :3, 3].mean()
        pred_all_target_extrinsic[..., :3, 3] = pred_all_target_extrinsic[..., :3, 3] * scale_factor
        pred_all_context_extrinsic[..., :3, 3] = pred_all_context_extrinsic[..., :3, 3] * scale_factor
        print("scale_factor:", scale_factor)

        # Build in-memory datasets from the images and predicted cameras.
        self.trainset = Dataset(
            split="train",
            images=ctx_images[0].detach().cpu().numpy(),
            camtoworlds=pred_all_context_extrinsic[0].detach().cpu().numpy(),
            Ks=pred_all_context_intrinsic[0].detach().cpu().numpy(),
            patch_size=cfg.patch_size,
            load_depths=cfg.depth_loss,
        )
        self.valset = Dataset(
            images=tgt_images[0].detach().cpu().numpy(),
            camtoworlds=pred_all_target_extrinsic[0].detach().cpu().numpy(),
            Ks=pred_all_target_intrinsic[0].detach().cpu().numpy(),
            split="val"
        )

        # Trainable splats + per-group optimizers, seeded from the encoder.
        feature_dim = 32 if cfg.app_opt else None
        self.splats, self.optimizers = create_splats_with_optimizers(
            gaussians=gaussians,
            init_num_pts=cfg.init_num_pts,
            init_extent=cfg.init_extent,
            init_opacity=cfg.init_opa,
            init_scale=cfg.init_scale,
            sh_degree=cfg.sh_degree,
            sparse_grad=cfg.sparse_grad,
            visible_adam=cfg.visible_adam,
            batch_size=cfg.batch_size,
            feature_dim=feature_dim,
            device=self.device,
            world_rank=world_rank,
            world_size=world_size,
            cfg=cfg,
        )
        print("Model initialized. Number of GS:", len(self.splats["means"]))

        # Densification strategy setup.
        self.cfg.strategy.check_sanity(self.splats, self.optimizers)

        if isinstance(self.cfg.strategy, DefaultStrategy):
            self.strategy_state = self.cfg.strategy.initialize_state(
                scene_scale=1.0
            )
        elif isinstance(self.cfg.strategy, MCMCStrategy):
            self.strategy_state = self.cfg.strategy.initialize_state()
        else:
            assert_never(self.cfg.strategy)

        # Optional splat compression backend.
        self.compression_method = None
        if cfg.compression is not None:
            if cfg.compression == "png":
                self.compression_method = PngCompression()
            else:
                raise ValueError(f"Unknown compression strategy: {cfg.compression}")

        # Optional camera-pose refinement module + optimizer.
        self.pose_optimizers = []
        if cfg.pose_opt:
            self.pose_adjust = CameraOptModule(len(self.trainset)).to(self.device)
            self.pose_adjust.zero_init()
            self.pose_optimizers = [
                torch.optim.Adam(
                    self.pose_adjust.parameters(),
                    lr=cfg.pose_opt_lr * math.sqrt(cfg.batch_size),
                    weight_decay=cfg.pose_opt_reg,
                )
            ]
            if world_size > 1:
                self.pose_adjust = DDP(self.pose_adjust)

        # Optional synthetic pose perturbation (for pose-opt experiments).
        if cfg.pose_noise > 0.0:
            self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device)
            self.pose_perturb.random_init(cfg.pose_noise)
            if world_size > 1:
                self.pose_perturb = DDP(self.pose_perturb)

        # Optional per-image appearance model; color head zero-initialized so
        # it starts as an identity correction.
        self.app_optimizers = []
        if cfg.app_opt:
            assert feature_dim is not None
            self.app_module = AppearanceOptModule(
                len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree
            ).to(self.device)

            torch.nn.init.zeros_(self.app_module.color_head[-1].weight)
            torch.nn.init.zeros_(self.app_module.color_head[-1].bias)
            self.app_optimizers = [
                torch.optim.Adam(
                    self.app_module.embeds.parameters(),
                    lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0,
                    weight_decay=cfg.app_opt_reg,
                ),
                torch.optim.Adam(
                    self.app_module.color_head.parameters(),
                    lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size),
                ),
            ]
            if world_size > 1:
                self.app_module = DDP(self.app_module)

        # Optional bilateral grid for per-image color correction.
        self.bil_grid_optimizers = []
        if cfg.use_bilateral_grid:
            self.bil_grids = BilateralGrid(
                len(self.trainset),
                grid_X=cfg.bilateral_grid_shape[0],
                grid_Y=cfg.bilateral_grid_shape[1],
                grid_W=cfg.bilateral_grid_shape[2],
            ).to(self.device)
            self.bil_grid_optimizers = [
                torch.optim.Adam(
                    self.bil_grids.parameters(),
                    lr=2e-3 * math.sqrt(cfg.batch_size),
                    eps=1e-15,
                ),
            ]

        # Evaluation metrics.
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device)
        self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device)

        if cfg.lpips_net == "alex":
            self.lpips = LearnedPerceptualImagePatchSimilarity(
                net_type="alex", normalize=True
            ).to(self.device)
        elif cfg.lpips_net == "vgg":
            # normalize=False: the vgg variant expects [-1, 1] inputs here.
            self.lpips = LearnedPerceptualImagePatchSimilarity(
                net_type="vgg", normalize=False
            ).to(self.device)
        else:
            raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}")

        # Optional interactive viewer.
        if not self.cfg.disable_viewer:
            self.server = viser.ViserServer(port=cfg.port, verbose=False)
            self.viewer = GsplatViewer(
                server=self.server,
                render_fn=self._viewer_render_fn,
                output_dir=Path(cfg.result_dir),
                mode="training",
            )
| | |
| | def rasterize_splats( |
| | self, |
| | camtoworlds: Tensor, |
| | Ks: Tensor, |
| | width: int, |
| | height: int, |
| | masks: Optional[Tensor] = None, |
| | rasterize_mode: Optional[Literal["classic", "antialiased"]] = None, |
| | camera_model: Optional[Literal["pinhole", "ortho", "fisheye"]] = None, |
| | **kwargs, |
| | ) -> Tuple[Tensor, Tensor, Dict]: |
| | means = self.splats["means"] |
| | |
| | |
| | quats = self.splats["quats"] |
| | scales = torch.exp(self.splats["scales"]) |
| | opacities = torch.sigmoid(self.splats["opacities"]) |
| | |
| | image_ids = kwargs.pop("image_ids", None) |
| | if self.cfg.app_opt: |
| | colors = self.app_module( |
| | features=self.splats["features"], |
| | embed_ids=image_ids, |
| | dirs=means[None, :, :] - camtoworlds[:, None, :3, 3], |
| | sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree), |
| | ) |
| | colors = colors + self.splats["colors"] |
| | colors = torch.sigmoid(colors) |
| | else: |
| | colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) |
| |
|
| | if rasterize_mode is None: |
| | rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" |
| | if camera_model is None: |
| | camera_model = self.cfg.camera_model |
| | |
| | |
| | render_colors, render_alphas, info = rasterization( |
| | means=means, |
| | quats=quats, |
| | scales=scales, |
| | opacities=opacities, |
| | colors=colors, |
| | |
| | viewmats=torch.linalg.inv(camtoworlds), |
| | Ks=Ks, |
| | width=width, |
| | height=height, |
| | packed=self.cfg.packed, |
| | absgrad=( |
| | self.cfg.strategy.absgrad |
| | if isinstance(self.cfg.strategy, DefaultStrategy) |
| | else False |
| | ), |
| | sparse_grad=self.cfg.sparse_grad, |
| | rasterize_mode=rasterize_mode, |
| | distributed=self.world_size > 1, |
| | camera_model=self.cfg.camera_model, |
| | radius_clip=0.1, |
| | backgrounds=torch.tensor([0.0, 0.0, 0.0]).cuda().unsqueeze(0).repeat(1, 1), |
| | **kwargs, |
| | ) |
| | if masks is not None: |
| | render_colors[~masks] = 0 |
| | return render_colors, render_alphas, info |
| |
|
| | def train(self): |
| | cfg = self.cfg |
| | device = self.device |
| | world_rank = self.world_rank |
| | world_size = self.world_size |
| |
|
| | |
| | if world_rank == 0: |
| | with open(f"{cfg.result_dir}/cfg.yml", "w") as f: |
| | yaml.dump(vars(cfg), f) |
| |
|
| | max_steps = cfg.max_steps |
| | init_step = 0 |
| |
|
| | schedulers = [ |
| | |
| | torch.optim.lr_scheduler.ExponentialLR( |
| | self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps) |
| | ), |
| | ] |
| | if cfg.pose_opt: |
| | |
| | schedulers.append( |
| | torch.optim.lr_scheduler.ExponentialLR( |
| | self.pose_optimizers[0], gamma=0.01 ** (1.0 / max_steps) |
| | ) |
| | ) |
| | if cfg.use_bilateral_grid: |
| | |
| | schedulers.append( |
| | torch.optim.lr_scheduler.ChainedScheduler( |
| | [ |
| | torch.optim.lr_scheduler.LinearLR( |
| | self.bil_grid_optimizers[0], |
| | start_factor=0.01, |
| | total_iters=1000, |
| | ), |
| | torch.optim.lr_scheduler.ExponentialLR( |
| | self.bil_grid_optimizers[0], gamma=0.01 ** (1.0 / max_steps) |
| | ), |
| | ] |
| | ) |
| | ) |
| |
|
| | trainloader = torch.utils.data.DataLoader( |
| | self.trainset, |
| | batch_size=cfg.batch_size, |
| | shuffle=True, |
| | num_workers=4, |
| | persistent_workers=True, |
| | pin_memory=True, |
| | ) |
| | trainloader_iter = iter(trainloader) |
| |
|
| | |
| | global_tic = time.time() |
| | pbar = tqdm.tqdm(range(init_step, max_steps)) |
| | for step in pbar: |
| | if not cfg.disable_viewer: |
| | while self.viewer.state == "paused": |
| | time.sleep(0.01) |
| | self.viewer.lock.acquire() |
| | tic = time.time() |
| |
|
| | try: |
| | data = next(trainloader_iter) |
| | except StopIteration: |
| | trainloader_iter = iter(trainloader) |
| | data = next(trainloader_iter) |
| |
|
| | camtoworlds = camtoworlds_gt = data["camtoworld"].to(device) |
| | Ks = data["K"].to(device) |
| | pixels = data["image"].to(device) / 255.0 |
| | num_train_rays_per_step = ( |
| | pixels.shape[0] * pixels.shape[1] * pixels.shape[2] |
| | ) |
| | image_ids = data["image_id"].to(device) |
| | masks = data["mask"].to(device) if "mask" in data else None |
| | if cfg.depth_loss: |
| | points = data["points"].to(device) |
| | depths_gt = data["depths"].to(device) |
| |
|
| | height, width = pixels.shape[1:3] |
| |
|
| | if cfg.pose_noise: |
| | camtoworlds = self.pose_perturb(camtoworlds, image_ids) |
| |
|
| | if cfg.pose_opt: |
| | camtoworlds = self.pose_adjust(camtoworlds, image_ids) |
| |
|
| | |
| | |
| | sh_degree_to_use = cfg.sh_degree |
| |
|
| | |
| | renders, alphas, info = self.rasterize_splats( |
| | camtoworlds=camtoworlds, |
| | Ks=Ks, |
| | width=width, |
| | height=height, |
| | sh_degree=sh_degree_to_use, |
| | near_plane=cfg.near_plane, |
| | far_plane=cfg.far_plane, |
| | image_ids=image_ids, |
| | render_mode="RGB+ED" if cfg.depth_loss else "RGB", |
| | masks=masks, |
| | ) |
| | if renders.shape[-1] == 4: |
| | colors, depths = renders[..., 0:3], renders[..., 3:4] |
| | else: |
| | colors, depths = renders, None |
| |
|
| | if cfg.use_bilateral_grid: |
| | grid_y, grid_x = torch.meshgrid( |
| | (torch.arange(height, device=self.device) + 0.5) / height, |
| | (torch.arange(width, device=self.device) + 0.5) / width, |
| | indexing="ij", |
| | ) |
| | grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) |
| | colors = slice(self.bil_grids, grid_xy, colors, image_ids)["rgb"] |
| |
|
| | if cfg.random_bkgd: |
| | bkgd = torch.rand(1, 3, device=device) |
| | colors = colors + bkgd * (1.0 - alphas) |
| |
|
| | self.cfg.strategy.step_pre_backward( |
| | params=self.splats, |
| | optimizers=self.optimizers, |
| | state=self.strategy_state, |
| | step=step, |
| | info=info, |
| | ) |
| | |
| | |
| | l1loss = F.l1_loss(colors, pixels) |
| | ssimloss = 1.0 - fused_ssim( |
| | colors.permute(0, 3, 1, 2), pixels.permute(0, 3, 1, 2), padding="valid" |
| | ) |
| | loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda |
| | if cfg.depth_loss: |
| | |
| | points = torch.stack( |
| | [ |
| | points[:, :, 0] / (width - 1) * 2 - 1, |
| | points[:, :, 1] / (height - 1) * 2 - 1, |
| | ], |
| | dim=-1, |
| | ) |
| | grid = points.unsqueeze(2) |
| | depths = F.grid_sample( |
| | depths.permute(0, 3, 1, 2), grid, align_corners=True |
| | ) |
| | depths = depths.squeeze(3).squeeze(1) |
| | |
| | disp = torch.where(depths > 0.0, 1.0 / depths, torch.zeros_like(depths)) |
| | disp_gt = 1.0 / depths_gt |
| | depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale |
| | loss += depthloss * cfg.depth_lambda |
| | if cfg.use_bilateral_grid: |
| | tvloss = 10 * total_variation_loss(self.bil_grids.grids) |
| | loss += tvloss |
| |
|
| | |
| | if cfg.opacity_reg > 0.0: |
| | loss = ( |
| | loss |
| | + cfg.opacity_reg |
| | * torch.abs(torch.sigmoid(self.splats["opacities"])).mean() |
| | ) |
| | if cfg.scale_reg > 0.0: |
| | loss = ( |
| | loss |
| | + cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean() |
| | ) |
| |
|
| | loss.backward() |
| |
|
| | desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " |
| | if cfg.depth_loss: |
| | desc += f"depth loss={depthloss.item():.6f}| " |
| | if cfg.pose_opt and cfg.pose_noise: |
| | |
| | pose_err = F.l1_loss(camtoworlds_gt, camtoworlds) |
| | desc += f"pose err={pose_err.item():.6f}| " |
| | pbar.set_description(desc) |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | if world_rank == 0 and cfg.tb_every > 0 and step % cfg.tb_every == 0: |
| | mem = torch.cuda.max_memory_allocated() / 1024**3 |
| | self.writer.add_scalar("train/loss", loss.item(), step) |
| | self.writer.add_scalar("train/l1loss", l1loss.item(), step) |
| | self.writer.add_scalar("train/ssimloss", ssimloss.item(), step) |
| | self.writer.add_scalar("train/num_GS", len(self.splats["means"]), step) |
| | self.writer.add_scalar("train/mem", mem, step) |
| | if cfg.depth_loss: |
| | self.writer.add_scalar("train/depthloss", depthloss.item(), step) |
| | if cfg.use_bilateral_grid: |
| | self.writer.add_scalar("train/tvloss", tvloss.item(), step) |
| | if cfg.tb_save_image: |
| | canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() |
| | canvas = canvas.reshape(-1, *canvas.shape[2:]) |
| | self.writer.add_image("train/render", canvas, step) |
| | self.writer.flush() |
| |
|
| | |
| | if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1: |
| | mem = torch.cuda.max_memory_allocated() / 1024**3 |
| | stats = { |
| | "mem": mem, |
| | "ellipse_time": time.time() - global_tic, |
| | "num_GS": len(self.splats["means"]), |
| | } |
| | print("Step: ", step, stats) |
| | with open( |
| | f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json", |
| | "w", |
| | ) as f: |
| | json.dump(stats, f) |
| | data = {"step": step, "splats": self.splats.state_dict()} |
| | if cfg.pose_opt: |
| | if world_size > 1: |
| | data["pose_adjust"] = self.pose_adjust.module.state_dict() |
| | else: |
| | data["pose_adjust"] = self.pose_adjust.state_dict() |
| | if cfg.app_opt: |
| | if world_size > 1: |
| | data["app_module"] = self.app_module.module.state_dict() |
| | else: |
| | data["app_module"] = self.app_module.state_dict() |
| | torch.save( |
| | data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt" |
| | ) |
| | if ( |
| | step in [i - 1 for i in cfg.ply_steps] or step == max_steps - 1 |
| | ) and cfg.save_ply: |
| |
|
| | if self.cfg.app_opt: |
| | |
| | rgb = self.app_module( |
| | features=self.splats["features"], |
| | embed_ids=None, |
| | dirs=torch.zeros_like(self.splats["means"][None, :, :]), |
| | sh_degree=sh_degree_to_use, |
| | ) |
| | rgb = rgb + self.splats["colors"] |
| | rgb = torch.sigmoid(rgb).squeeze(0).unsqueeze(1) |
| | sh0 = rgb_to_sh(rgb) |
| | shN = torch.empty([sh0.shape[0], 0, 3], device=sh0.device) |
| | else: |
| | sh0 = self.splats["sh0"] |
| | shN = self.splats["shN"] |
| | |
| |
|
| | means = self.splats["means"] |
| | scales = self.splats["scales"] |
| | quats = self.splats["quats"] |
| | opacities = self.splats["opacities"] |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | export_ply( |
| | means=means, |
| | scales=scales, |
| | rotations=quats, |
| | harmonics=torch.cat([sh0, shN], dim=1).permute(0, 2, 1), |
| | opacities=opacities.sigmoid(), |
| | path=Path(f"{self.ply_dir}/point_cloud_{step}.ply"), |
| | ) |
| |
|
| | |
| | if cfg.sparse_grad: |
| | assert cfg.packed, "Sparse gradients only work with packed mode." |
| | gaussian_ids = info["gaussian_ids"] |
| | for k in self.splats.keys(): |
| | grad = self.splats[k].grad |
| | if grad is None or grad.is_sparse: |
| | continue |
| | self.splats[k].grad = torch.sparse_coo_tensor( |
| | indices=gaussian_ids[None], |
| | values=grad[gaussian_ids], |
| | size=self.splats[k].size(), |
| | is_coalesced=len(Ks) == 1, |
| | ) |
| |
|
| | if cfg.visible_adam: |
| | gaussian_cnt = self.splats.means.shape[0] |
| | if cfg.packed: |
| | visibility_mask = torch.zeros_like( |
| | self.splats["opacities"], dtype=bool |
| | ) |
| | visibility_mask.scatter_(0, info["gaussian_ids"], 1) |
| | else: |
| | visibility_mask = (info["radii"] > 0).all(-1).any(0) |
| |
|
| | |
| | for optimizer in self.optimizers.values(): |
| | if cfg.visible_adam: |
| | optimizer.step(visibility_mask) |
| | else: |
| | optimizer.step() |
| | optimizer.zero_grad(set_to_none=True) |
| | for optimizer in self.pose_optimizers: |
| | optimizer.step() |
| | optimizer.zero_grad(set_to_none=True) |
| | for optimizer in self.app_optimizers: |
| | optimizer.step() |
| | optimizer.zero_grad(set_to_none=True) |
| | for optimizer in self.bil_grid_optimizers: |
| | optimizer.step() |
| | optimizer.zero_grad(set_to_none=True) |
| | for scheduler in schedulers: |
| | scheduler.step() |
| | |
| | |
| | if isinstance(self.cfg.strategy, DefaultStrategy): |
| | self.cfg.strategy.step_post_backward( |
| | params=self.splats, |
| | optimizers=self.optimizers, |
| | state=self.strategy_state, |
| | step=step, |
| | info=info, |
| | packed=cfg.packed, |
| | ) |
| | elif isinstance(self.cfg.strategy, MCMCStrategy): |
| | self.cfg.strategy.step_post_backward( |
| | params=self.splats, |
| | optimizers=self.optimizers, |
| | state=self.strategy_state, |
| | step=step, |
| | info=info, |
| | lr=schedulers[0].get_last_lr()[0], |
| | ) |
| | else: |
| | assert_never(self.cfg.strategy) |
| |
|
| | |
| | if step in [i - 1 for i in cfg.eval_steps]: |
| | self.eval(step) |
| | |
| |
|
| | |
| | if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]: |
| | self.run_compression(step=step) |
| |
|
| | if not cfg.disable_viewer: |
| | self.viewer.lock.release() |
| | num_train_steps_per_sec = 1.0 / (time.time() - tic) |
| | num_train_rays_per_sec = ( |
| | num_train_rays_per_step * num_train_steps_per_sec |
| | ) |
| | |
| | self.viewer.render_tab_state.num_train_rays_per_sec = ( |
| | num_train_rays_per_sec |
| | ) |
| | |
| | self.viewer.update(step, num_train_rays_per_step) |
| |
|
| | @torch.no_grad() |
| | def eval(self, step: int, stage: str = "val"): |
| | """Entry for evaluation.""" |
| | print("Running evaluation...") |
| | cfg = self.cfg |
| | device = self.device |
| | world_rank = self.world_rank |
| | world_size = self.world_size |
| |
|
| | valloader = torch.utils.data.DataLoader( |
| | self.valset, batch_size=1, shuffle=False, num_workers=1 |
| | ) |
| | ellipse_time = 0 |
| | metrics = defaultdict(list) |
| | for i, data in enumerate(valloader): |
| | camtoworlds = data["camtoworld"].to(device) |
| | Ks = data["K"].to(device) |
| | pixels = data["image"].to(device) / 255.0 |
| | masks = data["mask"].to(device) if "mask" in data else None |
| | height, width = pixels.shape[1:3] |
| |
|
| | torch.cuda.synchronize() |
| | tic = time.time() |
| | render_colors, _, _ = self.rasterize_splats( |
| | camtoworlds=camtoworlds, |
| | Ks=Ks, |
| | width=width, |
| | height=height, |
| | sh_degree=cfg.sh_degree, |
| | near_plane=cfg.near_plane, |
| | far_plane=cfg.far_plane, |
| | |
| | render_mode="RGB+ED", |
| | masks=masks, |
| | ) |
| | torch.cuda.synchronize() |
| | ellipse_time += time.time() - tic |
| |
|
| | colors = render_colors[..., :3] |
| | depths = render_colors[..., 3] |
| |
|
| | colors = torch.clamp(colors, 0.0, 1.0) |
| | canvas_list = [pixels, colors] |
| |
|
| | if world_rank == 0: |
| | |
| | canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() |
| | canvas = (canvas * 255).astype(np.uint8) |
| | imageio.imwrite( |
| | f"{self.render_dir}/{stage}_step{step}_{i:04d}.png", |
| | canvas, |
| | ) |
| | torchvision.utils.save_image(pixels.permute(0, 3, 1, 2), f"{self.render_dir}/gt_rgb_{stage}_step{step}_{i:04d}.png") |
| | torchvision.utils.save_image(colors.permute(0, 3, 1, 2), f"{self.render_dir}/render_rgb_{stage}_step{step}_{i:04d}.png") |
| | |
| | |
| |
|
| | pixels_p = pixels.permute(0, 3, 1, 2) |
| | colors_p = colors.permute(0, 3, 1, 2) |
| | |
| | metrics["psnr"].append(self.psnr(colors_p, pixels_p)) |
| | metrics["ssim"].append(self.ssim(colors_p, pixels_p)) |
| | metrics["lpips"].append(self.lpips(colors_p, pixels_p)) |
| | if cfg.use_bilateral_grid: |
| | cc_colors = color_correct(colors, pixels) |
| | cc_colors_p = cc_colors.permute(0, 3, 1, 2) |
| | metrics["cc_psnr"].append(self.psnr(cc_colors_p, pixels_p)) |
| |
|
| | if world_rank == 0: |
| | ellipse_time /= len(valloader) |
| |
|
| | stats = {k: torch.stack(v).mean().item() for k, v in metrics.items()} |
| | stats.update( |
| | { |
| | "ellipse_time": ellipse_time, |
| | "num_GS": len(self.splats["means"]), |
| | } |
| | ) |
| | print( |
| | f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} " |
| | f"Time: {stats['ellipse_time']:.3f}s/image " |
| | f"Number of GS: {stats['num_GS']}" |
| | ) |
| | |
| | with open(f"{self.stats_dir}/{stage}_step{step:04d}.json", "w") as f: |
| | json.dump(stats, f) |
| | |
| | for k, v in stats.items(): |
| | self.writer.add_scalar(f"{stage}/{k}", v, step) |
| | self.writer.flush() |
| |
|
    @torch.no_grad()
    def render_traj(self, step: int):
        """Render a smooth camera trajectory to an MP4 video.

        Builds a camera path ("interp", "ellipse", or "spiral") from the
        training poses, rasterizes RGB + expected depth for each frame, and
        writes a side-by-side RGB|depth video to
        ``<result_dir>/videos/traj_<step>.mp4``. No-op when
        ``cfg.disable_video`` is set.

        Args:
            step: Training step used to name the output video file.
        """
        if self.cfg.disable_video:
            return
        print("Running trajectory rendering...")
        cfg = self.cfg
        device = self.device

        # Drop the first/last 5 training poses — presumably to avoid edge
        # cameras distorting the generated path (TODO confirm intent).
        camtoworlds_all = self.parser.camtoworlds[5:-5]
        if cfg.render_traj_path == "interp":
            camtoworlds_all = generate_interpolated_path(
                camtoworlds_all, 1
            )  # 1 interpolated frame between each pair of input poses
        elif cfg.render_traj_path == "ellipse":
            # Use the mean camera height (z translation) for the ellipse.
            height = camtoworlds_all[:, 2, 3].mean()
            camtoworlds_all = generate_ellipse_path_z(
                camtoworlds_all, height=height
            )
        elif cfg.render_traj_path == "spiral":
            camtoworlds_all = generate_spiral_path(
                camtoworlds_all,
                bounds=self.parser.bounds * self.scene_scale,
                spiral_scale_r=self.parser.extconf["spiral_radius_scale"],
            )
        else:
            raise ValueError(
                f"Render trajectory type not supported: {cfg.render_traj_path}"
            )

        # The path generators return 3x4 pose matrices; append the [0,0,0,1]
        # row to make them full 4x4 homogeneous transforms.
        camtoworlds_all = np.concatenate(
            [
                camtoworlds_all,
                np.repeat(
                    np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds_all), axis=0
                ),
            ],
            axis=1,
        )

        camtoworlds_all = torch.from_numpy(camtoworlds_all).float().to(device)
        # Use the first camera's intrinsics and image size for every frame.
        K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device)
        width, height = list(self.parser.imsize_dict.values())[0]

        # Render the trajectory frame by frame into an mp4.
        video_dir = f"{cfg.result_dir}/videos"
        os.makedirs(video_dir, exist_ok=True)
        writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30)
        for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"):
            camtoworlds = camtoworlds_all[i : i + 1]
            Ks = K[None]

            renders, _, _ = self.rasterize_splats(
                camtoworlds=camtoworlds,
                Ks=Ks,
                width=width,
                height=height,
                sh_degree=cfg.sh_degree,
                near_plane=cfg.near_plane,
                far_plane=cfg.far_plane,
                render_mode="RGB+ED",  # RGB in channels 0:3, expected depth in 3
            )
            colors = torch.clamp(renders[..., 0:3], 0.0, 1.0)
            depths = renders[..., 3:4]
            # NOTE(review): per-frame min/max normalization — divides by zero
            # for a constant-depth frame and can flicker across frames.
            depths = (depths - depths.min()) / (depths.max() - depths.min())
            canvas_list = [colors, depths.repeat(1, 1, 1, 3)]

            # Concatenate RGB and (grayscale->RGB) depth along width.
            canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy()
            canvas = (canvas * 255).astype(np.uint8)
            writer.append_data(canvas)
        writer.close()
        print(f"Video saved to {video_dir}/traj_{step}.mp4")
| |
|
| | @torch.no_grad() |
| | def run_compression(self, step: int): |
| | """Entry for running compression.""" |
| | print("Running compression...") |
| | world_rank = self.world_rank |
| |
|
| | compress_dir = f"{cfg.result_dir}/compression/rank{world_rank}" |
| | os.makedirs(compress_dir, exist_ok=True) |
| |
|
| | self.compression_method.compress(compress_dir, self.splats) |
| |
|
| | |
| | splats_c = self.compression_method.decompress(compress_dir) |
| | for k in splats_c.keys(): |
| | self.splats[k].data = splats_c[k].to(self.device) |
| | self.eval(step=step, stage="compress") |
| |
|
| | @torch.no_grad() |
| | def _viewer_render_fn( |
| | self, camera_state: CameraState, render_tab_state: RenderTabState |
| | ): |
| | assert isinstance(render_tab_state, GsplatRenderTabState) |
| | if render_tab_state.preview_render: |
| | width = render_tab_state.render_width |
| | height = render_tab_state.render_height |
| | else: |
| | width = render_tab_state.viewer_width |
| | height = render_tab_state.viewer_height |
| | c2w = camera_state.c2w |
| | K = camera_state.get_K((width, height)) |
| | c2w = torch.from_numpy(c2w).float().to(self.device) |
| | K = torch.from_numpy(K).float().to(self.device) |
| |
|
| | RENDER_MODE_MAP = { |
| | "rgb": "RGB", |
| | "depth(accumulated)": "D", |
| | "depth(expected)": "ED", |
| | "alpha": "RGB", |
| | } |
| |
|
| | render_colors, render_alphas, info = self.rasterize_splats( |
| | camtoworlds=c2w[None], |
| | Ks=K[None], |
| | width=width, |
| | height=height, |
| | sh_degree=min(render_tab_state.max_sh_degree, self.cfg.sh_degree), |
| | near_plane=render_tab_state.near_plane, |
| | far_plane=render_tab_state.far_plane, |
| | radius_clip=render_tab_state.radius_clip, |
| | |
| | eps2d=render_tab_state.eps2d, |
| | backgrounds=torch.tensor([render_tab_state.backgrounds], device=self.device) |
| | / 255.0, |
| | render_mode=RENDER_MODE_MAP[render_tab_state.render_mode], |
| | rasterize_mode=render_tab_state.rasterize_mode, |
| | camera_model=render_tab_state.camera_model, |
| | ) |
| | render_tab_state.total_gs_count = len(self.splats["means"]) |
| | render_tab_state.rendered_gs_count = (info["radii"] > 0).all(-1).sum().item() |
| |
|
| | if render_tab_state.render_mode == "rgb": |
| | |
| | render_colors = render_colors[0, ..., 0:3].clamp(0, 1) |
| | renders = render_colors.cpu().numpy() |
| | elif render_tab_state.render_mode in ["depth(accumulated)", "depth(expected)"]: |
| | |
| | depth = render_colors[0, ..., 0:1] |
| | if render_tab_state.normalize_nearfar: |
| | near_plane = render_tab_state.near_plane |
| | far_plane = render_tab_state.far_plane |
| | else: |
| | near_plane = depth.min() |
| | far_plane = depth.max() |
| | depth_norm = (depth - near_plane) / (far_plane - near_plane + 1e-10) |
| | depth_norm = torch.clip(depth_norm, 0, 1) |
| | if render_tab_state.inverse: |
| | depth_norm = 1 - depth_norm |
| | renders = ( |
| | apply_float_colormap(depth_norm, render_tab_state.colormap) |
| | .cpu() |
| | .numpy() |
| | ) |
| | elif render_tab_state.render_mode == "alpha": |
| | alpha = render_alphas[0, ..., 0:1] |
| | if render_tab_state.inverse: |
| | alpha = 1 - alpha |
| | renders = ( |
| | apply_float_colormap(alpha, render_tab_state.colormap).cpu().numpy() |
| | ) |
| | return renders |
| |
|
| |
|
def main(local_rank: int, world_rank, world_size: int, cfg: Config):
    """Per-process entry point.

    Builds a ``Runner``, then either evaluates provided checkpoint(s)
    (optionally followed by compression) or trains from scratch and
    evaluates the final model.
    """
    # The viser viewer cannot be shared across ranks; force-disable it in
    # distributed runs.
    if world_size > 1 and not cfg.disable_viewer:
        cfg.disable_viewer = True
        if world_rank == 0:
            print("Viewer is disabled in distributed training.")

    runner = Runner(local_rank, world_rank, world_size, cfg)

    if cfg.ckpt is None:
        # Fresh run: train, then evaluate the final model.
        runner.train()
        runner.eval(step=runner.cfg.max_steps)
    else:
        # Eval-only: load per-rank checkpoints and concatenate their splats.
        ckpts = []
        for file in cfg.ckpt:
            ckpts.append(
                torch.load(file, map_location=runner.device, weights_only=True)
            )
        for k in runner.splats.keys():
            runner.splats[k].data = torch.cat([ckpt["splats"][k] for ckpt in ckpts])
        step = ckpts[0]["step"]
        runner.eval(step=step)

        if cfg.compression is not None:
            runner.run_compression(step=step)

    print("Training complete.")
| | |
| | |
| | |
| | |
| |
|
| |
|
if __name__ == "__main__":
    """
    Usage:

    ```bash
    # Single GPU training
    CUDA_VISIBLE_DEVICES=9 python -m examples.simple_trainer default

    # Distributed training on 4 GPUs: Effectively 4x batch size so run 4x less steps.
    CUDA_VISIBLE_DEVICES=0,1,2,3 python simple_trainer.py default --steps_scaler 0.25

    """

    # Named configs selectable on the CLI via tyro; each entry is
    # (help text, default Config).
    configs = {
        "default": (
            "Gaussian splatting training using densification heuristics from the original paper.",
            Config(
                strategy=DefaultStrategy(verbose=True),
            ),
        ),
        "mcmc": (
            "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.",
            Config(
                init_opa=0.5,
                init_scale=0.1,
                opacity_reg=0.01,
                scale_reg=0.01,
                strategy=MCMCStrategy(verbose=True),
            ),
        ),
    }
    cfg = tyro.extras.overridable_config_cli(configs)
    cfg.adjust_steps(cfg.steps_scaler)

    # PNG compression requires optional extras; fail early with install hints.
    if cfg.compression == "png":
        try:
            import plas
            import torchpq
        except ImportError as err:
            # Fixed: was a bare ``except`` that reported any error here —
            # even KeyboardInterrupt — as a missing dependency. Narrow to
            # ImportError and chain the original cause.
            raise ImportError(
                "To use PNG compression, you need to install "
                "torchpq (instruction at https://github.com/DeMoriarty/TorchPQ?tab=readme-ov-file#install) "
                "and plas (via 'pip install git+https://github.com/fraunhoferhhi/PLAS.git') "
            ) from err

    cli(main, cfg, verbose=True)
| |
|