| import json |
| import math |
| import os |
| import time |
| from collections import defaultdict |
| from dataclasses import dataclass, field |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple, Union |
|
|
| import imageio |
| import matplotlib |
| import torchvision |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import tqdm |
| import tyro |
| import viser |
| import yaml |
| import torchvision |
| import sys |
| from plyfile import PlyData, PlyElement |
|
|
| sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) |
|
|
| from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri |
| from src.model.types import Gaussians |
| from src.post_opt.datasets.colmap import Dataset, Parser |
| from src.post_opt.datasets.traj import ( |
| generate_ellipse_path_z, |
| generate_interpolated_path, |
| generate_spiral_path, |
| ) |
| from fused_ssim import fused_ssim |
|
|
| from src.utils.image import process_image |
| from src.post_opt.exporter import export_splats |
| from src.post_opt.lib_bilagrid import BilateralGrid, color_correct, slice, total_variation_loss |
| from torch import Tensor |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from torch.utils.tensorboard import SummaryWriter |
| from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure |
| from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity |
| from typing_extensions import Literal, assert_never |
| from src.post_opt.utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed |
|
|
| |
| from gsplat.compression import PngCompression |
| from gsplat.distributed import cli |
| |
| |
| from gsplat import rasterization |
| from gsplat.strategy import DefaultStrategy, MCMCStrategy |
| from src.post_opt.gsplat_viewer import GsplatViewer, GsplatRenderTabState |
| from nerfview import CameraState, RenderTabState, apply_float_colormap |
|
|
| import torch |
| from einops import rearrange |
| from jaxtyping import Float |
| from torch import Tensor |
| from scipy.spatial.transform import Rotation as R |
|
|
| from src.model.model.anysplat import AnySplat |
|
|
|
|
| |
def quaternion_to_matrix(
    quaternions: Tensor,
    eps: float = 1e-8,
) -> Tensor:
    """Convert quaternions to 3x3 rotation matrices.

    Args:
        quaternions: ``(*batch, 4)`` tensor in ``(x, y, z, w)`` order (the
            scipy / scalar-last convention).
        eps: small constant added to the squared norm to avoid division by
            zero for (near-)zero quaternions.

    Returns:
        ``(*batch, 3, 3)`` rotation matrices.
    """
    i, j, k, r = torch.unbind(quaternions, dim=-1)
    # 2 / |q|^2: folding the normalization into the doubled products below
    # handles non-unit quaternions without an explicit normalize.
    two_s = 2 / ((quaternions * quaternions).sum(dim=-1) + eps)

    flat = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    # Reshape the trailing 9-vector into a 3x3 matrix. Native op instead of
    # einops.rearrange("... (i j) -> ... i j") — identical result, one fewer
    # per-call dependency.
    return flat.unflatten(-1, (3, 3))
|
|
def construct_list_of_attributes(num_rest: int) -> list[str]:
    """Return the ordered PLY vertex property names for a Gaussian splat.

    Layout: position, (zero) normals, SH DC band, ``num_rest`` higher-order SH
    coefficients, opacity, 3 scales, 4 rotation components. The order must
    match the column order assembled in ``export_ply``.
    """
    return (
        ["x", "y", "z", "nx", "ny", "nz"]
        + [f"f_dc_{c}" for c in range(3)]
        + [f"f_rest_{c}" for c in range(num_rest)]
        + ["opacity"]
        + [f"scale_{c}" for c in range(3)]
        + [f"rot_{c}" for c in range(4)]
    )
|
|
def export_ply(
    means: Float[Tensor, "gaussian 3"],
    scales: Float[Tensor, "gaussian 3"],
    rotations: Float[Tensor, "gaussian 4"],
    harmonics: Float[Tensor, "gaussian 3 d_sh"],
    opacities: Float[Tensor, " gaussian"],
    path: Path,
    shift_and_scale: bool = False,
    save_sh_dc_only: bool = True,
):
    """Write Gaussians to a binary PLY file in the common 3DGS vertex layout.

    Args:
        means: Gaussian centers.
        scales: per-axis scales, written as-is (no log/exp applied here — the
            caller decides the parameterization).
        rotations: quaternions in (x, y, z, w) order.
        harmonics: SH coefficients with RGB in dim 1: (gaussian, 3, d_sh).
        opacities: per-Gaussian opacities, written as-is.
        path: output .ply path; parent directories are created if missing.
        shift_and_scale: if True, recenter by the per-axis median and rescale
            by the 95th-percentile absolute coordinate before export.
        save_sh_dc_only: if True, drop the higher-order SH bands (smaller file).
    """
    if shift_and_scale:
        # Center the scene on the per-axis median of the positions.
        means = means - means.median(dim=0).values

        # Rescale so that ~95% of coordinates fall within [-1, 1]; scales must
        # shrink by the same factor to keep the splats consistent.
        scale_factor = means.abs().quantile(0.95, dim=0).max()
        means = means / scale_factor
        scales = scales / scale_factor

    # Round-trip through scipy (quat -> matrix -> quat) to normalize, then
    # reorder (x, y, z, w) -> (w, x, y, z), the convention 3DGS PLY readers expect.
    rotations = R.from_quat(rotations.detach().cpu().numpy()).as_matrix()
    rotations = R.from_matrix(rotations).as_quat()
    x, y, z, w = rearrange(rotations, "g xyzw -> xyzw g")
    rotations = np.stack((w, x, y, z), axis=-1)

    # Split SH into the DC band (f_dc) and the flattened higher-order bands (f_rest).
    f_dc = harmonics[..., 0]
    f_rest = harmonics[..., 1:].flatten(start_dim=1)

    # Column order below must match construct_list_of_attributes().
    dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(0 if save_sh_dc_only else f_rest.shape[1])]
    elements = np.empty(means.shape[0], dtype=dtype_full)
    attributes = [
        means.detach().cpu().numpy(),
        torch.zeros_like(means).detach().cpu().numpy(),  # normals: unused, zeros
        f_dc.detach().cpu().contiguous().numpy(),
        f_rest.detach().cpu().contiguous().numpy(),
        opacities[..., None].detach().cpu().numpy(),
        scales.detach().cpu().numpy(),
        rotations,
    ]
    if save_sh_dc_only:
        # Drop f_rest (index 3) so only the DC color component is written;
        # matches the num_rest=0 dtype chosen above.
        attributes.pop(3)

    attributes = np.concatenate(attributes, axis=1)
    elements[:] = list(map(tuple, attributes))
    path.parent.mkdir(exist_ok=True, parents=True)
    PlyData([PlyElement.describe(elements, "vertex")]).write(path)
|
|
def colorize_depth_maps(depth_map, min_depth=0.0, max_depth=1.0, cmap="Spectral", valid_mask=None):
    """Colorize depth maps with a matplotlib colormap.

    Args:
        depth_map: (H, W) or (B, H, W) depths as torch.Tensor or np.ndarray.
        min_depth, max_depth: currently unused; kept for API compatibility.
            The map is normalized by its own min/max instead.
        cmap: matplotlib colormap name.
        valid_mask: optional boolean mask (broadcastable to the depth shape);
            invalid pixels are set to 0.

    Returns:
        (B, 3, H, W) RGB values in [0, 1]; a torch.Tensor if the input was a
        tensor, otherwise an np.ndarray.

    Raises:
        TypeError: if depth_map is neither a torch.Tensor nor an np.ndarray.
    """
    assert len(depth_map.shape) >= 2, "Invalid dimension"

    if isinstance(depth_map, torch.Tensor):
        # .cpu() so CUDA tensors can be converted (.numpy() fails on GPU tensors).
        depth = depth_map.detach().clone().squeeze().cpu().numpy()
    elif isinstance(depth_map, np.ndarray):
        depth = depth_map.copy().squeeze()
    else:
        # Previously an unsupported type fell through to an unbound-variable
        # NameError; fail with a clear message instead.
        raise TypeError(f"Unsupported depth_map type: {type(depth_map)!r}")

    # Promote a single map to a batch of one.
    if depth.ndim < 3:
        depth = depth[np.newaxis, :, :]

    cm = matplotlib.colormaps[cmap]
    # Normalize by the map's own min/max (NaN if the map is constant — same
    # behavior as before).
    depth = ((depth - depth.min()) / (depth.max() - depth.min())).clip(0, 1)
    img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3]  # (B, H, W, 3), drop alpha
    img_colored_np = np.rollaxis(img_colored_np, 3, 1)  # -> (B, 3, H, W)

    if valid_mask is not None:
        # Check the mask's own type (previously keyed on depth_map's type,
        # which broke mixed tensor/ndarray inputs).
        if isinstance(valid_mask, torch.Tensor):
            valid_mask = valid_mask.detach().cpu().numpy()
        valid_mask = valid_mask.squeeze()
        if valid_mask.ndim < 3:
            valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
        else:
            valid_mask = valid_mask[:, np.newaxis, :, :]
        valid_mask = np.repeat(valid_mask, 3, axis=1)
        img_colored_np[~valid_mask] = 0

    if isinstance(depth_map, torch.Tensor):
        return torch.from_numpy(img_colored_np).float()
    return img_colored_np
|
|
def build_covariance(
    scale: Float[Tensor, "*#batch 3"],
    rotation_xyzw: Float[Tensor, "*#batch 4"],
) -> Float[Tensor, "*batch 3 3"]:
    """Assemble 3x3 Gaussian covariances as R @ S @ S^T @ R^T from per-axis
    scales and quaternions in (x, y, z, w) order."""
    rot_mat = quaternion_to_matrix(rotation_xyzw)
    scale_mat = scale.diag_embed()
    # Same left-to-right product as R S S^T R^T, spelled with method calls.
    return (
        rot_mat
        .matmul(scale_mat)
        .matmul(scale_mat.transpose(-2, -1))
        .matmul(rot_mat.transpose(-2, -1))
    )
|
|
|
|
@dataclass
class Config:
    """Hyper-parameters / CLI options for AnySplat post-optimization (parsed with tyro)."""

    # Disable the interactive viser viewer.
    disable_viewer: bool = True
    # Paths to checkpoint .pt files (not consumed in the visible code path — TODO confirm).
    ckpt: Optional[List[str]] = None
    # Compression strategy for eval-time compression; only "png" is supported.
    compression: Optional[Literal["png"]] = None
    # Render trajectory type (not consumed in the visible code path — TODO confirm).
    render_traj_path: str = "interp"

    # Directory of input images; Runner reads all *.png/*.jpg/*.jpeg from here.
    data_dir: str = "data/360_v2/garden"
    # Dataset downsample factor (not consumed in the visible code path — TODO confirm).
    data_factor: int = 4
    # Directory where checkpoints, stats, renders and PLYs are written.
    result_dir: str = "results/garden"
    # Every N-th image is held out as a validation (target) view.
    test_every: int = 8
    # Random-crop patch size forwarded to the training Dataset (None = full images).
    patch_size: Optional[int] = None
    # Global scaler for scene-size-related parameters (not consumed in the visible code path — TODO confirm).
    global_scale: float = 1.0
    # Whether to normalize the world space (not consumed in the visible code path — TODO confirm).
    normalize_world_space: bool = True
    # Camera model used for rasterization.
    camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole"

    # Port for the viser viewer server.
    port: int = 8080

    # Training batch size; learning rates are scaled by sqrt(batch_size).
    batch_size: int = 1
    # Global factor applied to all step-based schedules via adjust_steps().
    steps_scaler: float = 1.0

    # Number of training steps.
    max_steps: int = 3_000
    # Steps at which to run evaluation.
    eval_steps: List[int] = field(default_factory=lambda: [1, 1_000, 3_000, 7_000, 10_000])
    # Steps at which to save checkpoints.
    save_steps: List[int] = field(default_factory=lambda: [1, 1_000, 3_000, 7_000, 10_000])
    # Whether to export .ply point clouds at ply_steps.
    save_ply: bool = False
    # Steps at which to export the model as .ply.
    ply_steps: List[int] = field(default_factory=lambda: [1, 1_000, 3_000, 7_000, 10_000])
    # Disable video generation (not consumed in the visible code path — TODO confirm).
    disable_video: bool = False

    # Initialization strategy (not consumed in the visible code path — TODO confirm).
    init_type: str = "sfm"
    # Initial number of Gaussians (unused: splats come from the AnySplat encoder).
    init_num_pts: int = 100_000
    # Initial extent of Gaussians (unused: splats come from the AnySplat encoder).
    init_extent: float = 3.0
    # Degree of spherical harmonics.
    sh_degree: int = 4
    # Interval (steps) for increasing the active SH degree.
    sh_degree_interval: int = 1000
    # Initial opacity (unused: opacities come from the AnySplat encoder).
    init_opa: float = 0.1
    # Initial scale (unused: scales come from the AnySplat encoder).
    init_scale: float = 1.0
    # Weight of the SSIM term in the photometric loss.
    ssim_lambda: float = 0.2

    # Near plane clipping distance.
    near_plane: float = 1e-10
    # Far plane clipping distance.
    far_plane: float = 1e10

    # Densification strategy for the Gaussians (gsplat Default or MCMC).
    strategy: Union[DefaultStrategy, MCMCStrategy] = field(
        default_factory=DefaultStrategy
    )
    # Packed rasterization mode (less memory, required for sparse_grad).
    packed: bool = False
    # Use sparse gradients for the splat optimizers (experimental).
    sparse_grad: bool = False
    # Use gsplat's visibility-masked (selective) Adam (experimental).
    visible_adam: bool = False
    # Anti-aliased rasterization mode.
    antialiased: bool = False

    # Composite renders over a random background each step.
    random_bkgd: bool = False

    # Opacity regularization weight (0 disables).
    opacity_reg: float = 0.0
    # Scale regularization weight (0 disables).
    scale_reg: float = 0.0

    # Enable camera pose optimization (residual per-view adjustment).
    pose_opt: bool = True
    # Learning rate for pose optimization.
    pose_opt_lr: float = 1e-5
    # Weight decay for pose optimization.
    pose_opt_reg: float = 1e-6
    # Noise added to camera extrinsics (useful for debugging pose_opt).
    pose_noise: float = 0.0

    # Enable per-image appearance optimization (experimental).
    app_opt: bool = False
    # Appearance embedding dimension.
    app_embed_dim: int = 16
    # Learning rate for appearance optimization.
    app_opt_lr: float = 1e-3
    # Weight decay for appearance optimization.
    app_opt_reg: float = 1e-6

    # Enable a bilateral grid for per-image color correction.
    use_bilateral_grid: bool = False
    # Bilateral grid shape (grid_X, grid_Y, grid_W).
    bilateral_grid_shape: Tuple[int, int, int] = (16, 16, 8)

    # Enable sparse depth supervision (requires depths in the dataset).
    depth_loss: bool = False
    # Weight of the depth (disparity) loss.
    depth_lambda: float = 1e-2

    # Log to tensorboard every this many steps (0 disables).
    tb_every: int = 100
    # Also log training images to tensorboard.
    tb_save_image: bool = False

    # Backbone for the LPIPS metric.
    lpips_net: Literal["vgg", "alex"] = "vgg"

    # Per-parameter-group learning rates for the splat optimizers.
    lr_means: float = 1.6e-4
    lr_scales: float = 5e-3
    lr_quats: float = 1e-3
    lr_opacities: float = 5e-2
    lr_sh: float = 2.5e-3

    def adjust_steps(self, factor: float):
        """Scale every step-based schedule by ``factor`` (cfg.steps_scaler)."""
        self.eval_steps = [int(i * factor) for i in self.eval_steps]
        self.save_steps = [int(i * factor) for i in self.save_steps]
        self.ply_steps = [int(i * factor) for i in self.ply_steps]
        self.max_steps = int(self.max_steps * factor)
        self.sh_degree_interval = int(self.sh_degree_interval * factor)

        strategy = self.strategy
        if isinstance(strategy, DefaultStrategy):
            # NOTE(review): these constants are NOT scaled by ``factor``. With
            # the default max_steps=3000 they push every refine/reset beyond
            # the end of training (and refine_stop_iter=0 stops refinement
            # immediately), i.e. densification is effectively disabled.
            # Presumably intentional for post-optimization — confirm.
            strategy.refine_start_iter = 30000
            strategy.refine_stop_iter = 0
            strategy.reset_every = 30000
            strategy.refine_every = 30000

        elif isinstance(strategy, MCMCStrategy):
            strategy.refine_start_iter = int(strategy.refine_start_iter * factor)
            strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
            strategy.refine_every = int(strategy.refine_every * factor)
        else:
            assert_never(strategy)
|
|
|
|
def create_splats_with_optimizers(
    gaussians: "Gaussians",
    init_num_pts: int = 100_000,
    init_extent: float = 3.0,
    init_opacity: float = 0.1,
    init_scale: float = 1.0,
    sh_degree: int = 3,
    sparse_grad: bool = False,
    visible_adam: bool = False,
    batch_size: int = 1,
    feature_dim: Optional[int] = None,
    device: str = "cuda",
    world_rank: int = 0,
    world_size: int = 1,
    cfg: "Config" = None,
) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]:
    """Build trainable splat parameters and one optimizer per parameter group.

    Splats are initialized from the encoder-predicted ``gaussians`` (batch
    index 0); Gaussians with opacity <= 0.01 are pruned.

    Args:
        gaussians: encoder output with ``means``/``scales``/``rotations``/
            ``opacities``/``harmonics`` batched tensors.
        sparse_grad: use ``torch.optim.SparseAdam`` instead of Adam.
        visible_adam: use gsplat's ``SelectiveAdam`` (visibility-masked).
        batch_size, world_size: used to scale lr/eps/betas with the effective
            batch size.
        device: target device for the parameters.
        cfg: provides the per-group learning rates (``lr_means`` etc.).

    Note:
        ``init_num_pts``, ``init_extent``, ``init_opacity``, ``init_scale``,
        ``sh_degree``, ``feature_dim`` and ``world_rank`` are currently unused
        and kept only for interface compatibility.

    Returns:
        (splats, optimizers) — a ParameterDict and a dict of per-group optimizers.
    """
    # Store parameters in unconstrained form: log-scales and logit-opacities,
    # so downstream code applies exp/sigmoid activations.
    points = gaussians.means[0].detach().float()
    scales = torch.log(gaussians.scales[0].detach().float())
    quats = gaussians.rotations[0].detach().float()
    opacities = torch.logit(gaussians.opacities[0].detach().float())
    # (N, 3, K) -> (N, K, 3): SH coefficient index first, RGB last.
    harmonics = gaussians.harmonics[0].detach().float().permute(0, 2, 1).contiguous()

    N = points.shape[0]

    scene_scale = 1.0
    # Prune near-transparent Gaussians: sigmoid(logit(p)) == p, so this keeps p > 0.01.
    masks = opacities.sigmoid() > 0.01
    harmonics = harmonics[masks]
    params = [
        # name, value, lr
        ("means", torch.nn.Parameter(points[masks]), cfg.lr_means * scene_scale),
        ("scales", torch.nn.Parameter(scales[masks]), cfg.lr_scales),
        ("quats", torch.nn.Parameter(quats[masks]), cfg.lr_quats),
        ("opacities", torch.nn.Parameter(opacities[masks]), cfg.lr_opacities),
    ]
    # DC band learns 20x faster than the higher-order SH bands.
    params.append(("sh0", torch.nn.Parameter(harmonics[:, :1, :]), cfg.lr_sh))
    params.append(("shN", torch.nn.Parameter(harmonics[:, 1:, :]), cfg.lr_sh / 20))

    splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device)

    # Scale lr/eps/betas with the effective (global) batch size.
    BS = batch_size * world_size
    if sparse_grad:
        optimizer_class = torch.optim.SparseAdam
    elif visible_adam:
        # Imported lazily: the bare name previously raised NameError because
        # SelectiveAdam was never imported in this module.
        from gsplat.optimizers import SelectiveAdam

        optimizer_class = SelectiveAdam
    else:
        optimizer_class = torch.optim.Adam
    optimizers = {
        name: optimizer_class(
            [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}],
            eps=1e-15 / math.sqrt(BS),
            betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)),
        )
        for name, _, lr in params
    }
    return splats, optimizers
|
|
|
|
| class Runner: |
| """Engine for training and testing.""" |
|
|
    def __init__(
        self, local_rank: int, world_rank, world_size: int, cfg: Config
    ) -> None:
        """Set up output directories, run AnySplat feed-forward inference to
        initialize splats and camera poses, build train/val datasets, and
        create all parameters, optimizers and metrics."""
        set_random_seed(42 + local_rank)

        self.cfg = cfg
        self.world_rank = world_rank
        self.local_rank = local_rank
        self.world_size = world_size
        self.device = f"cuda:{local_rank}"

        # Where to dump results.
        os.makedirs(cfg.result_dir, exist_ok=True)

        # Setup output directories.
        self.ckpt_dir = f"{cfg.result_dir}/ckpts"
        os.makedirs(self.ckpt_dir, exist_ok=True)
        self.stats_dir = f"{cfg.result_dir}/stats"
        os.makedirs(self.stats_dir, exist_ok=True)
        self.render_dir = f"{cfg.result_dir}/renders"
        os.makedirs(self.render_dir, exist_ok=True)
        self.ply_dir = f"{cfg.result_dir}/ply"
        os.makedirs(self.ply_dir, exist_ok=True)

        # Tensorboard writer.
        self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb")

        # Frozen AnySplat model: used only for feed-forward initialization of
        # the Gaussians and camera poses; its weights are never trained here.
        model = AnySplat.from_pretrained("lhjiang/anysplat")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.eval()
        for param in model.parameters():
            param.requires_grad = False

        # Split images by index: every cfg.test_every-th image becomes a
        # validation (target) view, the rest are context (training) views.
        image_folder = cfg.data_dir
        image_names = sorted([os.path.join(image_folder, f) for f in os.listdir(image_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
        images = [process_image(img_path) for img_path in image_names]
        ctx_indices = [idx for idx, name in enumerate(image_names) if idx % cfg.test_every != 0]
        tgt_indices = [idx for idx, name in enumerate(image_names) if idx % cfg.test_every == 0]

        ctx_images = torch.stack([images[i] for i in ctx_indices], dim=0).unsqueeze(0).to(device)
        tgt_images = torch.stack([images[i] for i in tgt_indices], dim=0).unsqueeze(0).to(device)
        # Remap to [0, 1] (assumes process_image returns [-1, 1] — TODO confirm).
        ctx_images = (ctx_images+1)*0.5
        tgt_images = (tgt_images+1)*0.5
        b, v, _, h, w = tgt_images.shape

        # Feed-forward prediction: Gaussians + context-view poses.
        encoder_output = model.encoder(
            ctx_images,
            global_step=0,
            visualization_dump={},
        )
        gaussians, pred_context_pose = encoder_output.gaussians, encoder_output.pred_context_pose

        # Run the VGGT aggregator + camera head over context AND target views
        # to obtain poses for every image (targets have no encoder poses).
        num_context_view = ctx_images.shape[1]
        vggt_input_image = torch.cat((ctx_images, tgt_images), dim=1).to(torch.bfloat16)
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
            aggregated_tokens_list, patch_start_idx = model.encoder.aggregator(vggt_input_image, intermediate_layer_idx=model.encoder.cfg.intermediate_layer_idx)
            with torch.cuda.amp.autocast(enabled=False):
                # The camera head runs in fp32 for numerical stability.
                fp32_tokens = [token.float() for token in aggregated_tokens_list]
                pred_all_pose_enc = model.encoder.camera_head(fp32_tokens)[-1]
                pred_all_extrinsic, pred_all_intrinsic = pose_encoding_to_extri_intri(pred_all_pose_enc, vggt_input_image.shape[-2:])

        # Pad the 3x4 world-to-camera matrices to 4x4 and invert -> camera-to-world.
        extrinsic_padding = torch.tensor([0, 0, 0, 1], device=pred_all_extrinsic.device, dtype=pred_all_extrinsic.dtype).view(1, 1, 1, 4).repeat(b, vggt_input_image.shape[1], 1, 1)
        pred_all_extrinsic = torch.cat([pred_all_extrinsic, extrinsic_padding], dim=2).inverse()

        # Normalize intrinsics to unit image coordinates.
        pred_all_intrinsic[:, :, 0] = pred_all_intrinsic[:, :, 0] / w
        pred_all_intrinsic[:, :, 1] = pred_all_intrinsic[:, :, 1] / h
        pred_all_context_extrinsic, pred_all_target_extrinsic = pred_all_extrinsic[:, :num_context_view], pred_all_extrinsic[:, num_context_view:]
        pred_all_context_intrinsic, pred_all_target_intrinsic = pred_all_intrinsic[:, :num_context_view], pred_all_intrinsic[:, num_context_view:]

        # Align the camera-head poses with the encoder's (Gaussian) coordinate
        # scale by matching the mean context-camera translations.
        scale_factor = pred_context_pose['extrinsic'][:, :, :3, 3].mean() / pred_all_context_extrinsic[:, :, :3, 3].mean()
        pred_all_target_extrinsic[..., :3, 3] = pred_all_target_extrinsic[..., :3, 3] * scale_factor
        pred_all_context_extrinsic[..., :3, 3] = pred_all_context_extrinsic[..., :3, 3] * scale_factor
        print("scale_factor:", scale_factor)

        # In-memory datasets built from the predicted images + poses.
        self.trainset = Dataset(
            split="train",
            images=ctx_images[0].detach().cpu().numpy(),
            camtoworlds=pred_all_context_extrinsic[0].detach().cpu().numpy(),
            Ks=pred_all_context_intrinsic[0].detach().cpu().numpy(),
            patch_size=cfg.patch_size,
            load_depths=cfg.depth_loss,
        )
        self.valset = Dataset(
            images=tgt_images[0].detach().cpu().numpy(),
            camtoworlds=pred_all_target_extrinsic[0].detach().cpu().numpy(),
            Ks=pred_all_target_intrinsic[0].detach().cpu().numpy(),
            split="val"
        )

        # Trainable Gaussian parameters + per-group optimizers.
        feature_dim = 32 if cfg.app_opt else None
        self.splats, self.optimizers = create_splats_with_optimizers(
            gaussians=gaussians,
            init_num_pts=cfg.init_num_pts,
            init_extent=cfg.init_extent,
            init_opacity=cfg.init_opa,
            init_scale=cfg.init_scale,
            sh_degree=cfg.sh_degree,
            sparse_grad=cfg.sparse_grad,
            visible_adam=cfg.visible_adam,
            batch_size=cfg.batch_size,
            feature_dim=feature_dim,
            device=self.device,
            world_rank=world_rank,
            world_size=world_size,
            cfg=cfg,
        )
        print("Model initialized. Number of GS:", len(self.splats["means"]))

        # Densification strategy: verify the params/optimizers wiring.
        self.cfg.strategy.check_sanity(self.splats, self.optimizers)

        if isinstance(self.cfg.strategy, DefaultStrategy):
            self.strategy_state = self.cfg.strategy.initialize_state(
                scene_scale=1.0
            )
        elif isinstance(self.cfg.strategy, MCMCStrategy):
            self.strategy_state = self.cfg.strategy.initialize_state()
        else:
            assert_never(self.cfg.strategy)

        # Optional eval-time compression method.
        self.compression_method = None
        if cfg.compression is not None:
            if cfg.compression == "png":
                self.compression_method = PngCompression()
            else:
                raise ValueError(f"Unknown compression strategy: {cfg.compression}")

        # Per-view camera pose refinement (residual adjustment, zero-initialized).
        self.pose_optimizers = []
        if cfg.pose_opt:
            self.pose_adjust = CameraOptModule(len(self.trainset)).to(self.device)
            self.pose_adjust.zero_init()
            self.pose_optimizers = [
                torch.optim.Adam(
                    self.pose_adjust.parameters(),
                    lr=cfg.pose_opt_lr * math.sqrt(cfg.batch_size),
                    weight_decay=cfg.pose_opt_reg,
                )
            ]
            if world_size > 1:
                self.pose_adjust = DDP(self.pose_adjust)

        # Optional pose perturbation (for debugging pose optimization).
        if cfg.pose_noise > 0.0:
            self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device)
            self.pose_perturb.random_init(cfg.pose_noise)
            if world_size > 1:
                self.pose_perturb = DDP(self.pose_perturb)

        # Optional per-image appearance embeddings.
        self.app_optimizers = []
        if cfg.app_opt:
            assert feature_dim is not None
            self.app_module = AppearanceOptModule(
                len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree
            ).to(self.device)
            # Zero-init the last layer so the appearance residual starts at zero.
            torch.nn.init.zeros_(self.app_module.color_head[-1].weight)
            torch.nn.init.zeros_(self.app_module.color_head[-1].bias)
            self.app_optimizers = [
                torch.optim.Adam(
                    self.app_module.embeds.parameters(),
                    lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0,
                    weight_decay=cfg.app_opt_reg,
                ),
                torch.optim.Adam(
                    self.app_module.color_head.parameters(),
                    lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size),
                ),
            ]
            if world_size > 1:
                self.app_module = DDP(self.app_module)

        # Optional bilateral grid for per-image color correction.
        self.bil_grid_optimizers = []
        if cfg.use_bilateral_grid:
            self.bil_grids = BilateralGrid(
                len(self.trainset),
                grid_X=cfg.bilateral_grid_shape[0],
                grid_Y=cfg.bilateral_grid_shape[1],
                grid_W=cfg.bilateral_grid_shape[2],
            ).to(self.device)
            self.bil_grid_optimizers = [
                torch.optim.Adam(
                    self.bil_grids.parameters(),
                    lr=2e-3 * math.sqrt(cfg.batch_size),
                    eps=1e-15,
                ),
            ]

        # Metrics.
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device)
        self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device)

        if cfg.lpips_net == "alex":
            self.lpips = LearnedPerceptualImagePatchSimilarity(
                net_type="alex", normalize=True
            ).to(self.device)
        elif cfg.lpips_net == "vgg":
            # normalize=False matches the LPIPS-vgg setup used by the official
            # 3DGS codebase (inputs expected in [-1, 1]) — TODO confirm.
            self.lpips = LearnedPerceptualImagePatchSimilarity(
                net_type="vgg", normalize=False
            ).to(self.device)
        else:
            raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}")

        # Interactive viewer (optional).
        if not self.cfg.disable_viewer:
            self.server = viser.ViserServer(port=cfg.port, verbose=False)
            self.viewer = GsplatViewer(
                server=self.server,
                render_fn=self._viewer_render_fn,
                output_dir=Path(cfg.result_dir),
                mode="training",
            )
| |
| def rasterize_splats( |
| self, |
| camtoworlds: Tensor, |
| Ks: Tensor, |
| width: int, |
| height: int, |
| masks: Optional[Tensor] = None, |
| rasterize_mode: Optional[Literal["classic", "antialiased"]] = None, |
| camera_model: Optional[Literal["pinhole", "ortho", "fisheye"]] = None, |
| **kwargs, |
| ) -> Tuple[Tensor, Tensor, Dict]: |
| means = self.splats["means"] |
| |
| |
| quats = self.splats["quats"] |
| scales = torch.exp(self.splats["scales"]) |
| opacities = torch.sigmoid(self.splats["opacities"]) |
| |
| image_ids = kwargs.pop("image_ids", None) |
| if self.cfg.app_opt: |
| colors = self.app_module( |
| features=self.splats["features"], |
| embed_ids=image_ids, |
| dirs=means[None, :, :] - camtoworlds[:, None, :3, 3], |
| sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree), |
| ) |
| colors = colors + self.splats["colors"] |
| colors = torch.sigmoid(colors) |
| else: |
| colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) |
|
|
| if rasterize_mode is None: |
| rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" |
| if camera_model is None: |
| camera_model = self.cfg.camera_model |
| |
| |
| render_colors, render_alphas, info = rasterization( |
| means=means, |
| quats=quats, |
| scales=scales, |
| opacities=opacities, |
| colors=colors, |
| |
| viewmats=torch.linalg.inv(camtoworlds), |
| Ks=Ks, |
| width=width, |
| height=height, |
| packed=self.cfg.packed, |
| absgrad=( |
| self.cfg.strategy.absgrad |
| if isinstance(self.cfg.strategy, DefaultStrategy) |
| else False |
| ), |
| sparse_grad=self.cfg.sparse_grad, |
| rasterize_mode=rasterize_mode, |
| distributed=self.world_size > 1, |
| camera_model=self.cfg.camera_model, |
| radius_clip=0.1, |
| backgrounds=torch.tensor([0.0, 0.0, 0.0]).cuda().unsqueeze(0).repeat(1, 1), |
| **kwargs, |
| ) |
| if masks is not None: |
| render_colors[~masks] = 0 |
| return render_colors, render_alphas, info |
|
|
    def train(self):
        """Run the post-optimization loop over the context (training) views."""
        cfg = self.cfg
        device = self.device
        world_rank = self.world_rank
        world_size = self.world_size

        # Dump the resolved config once (rank 0 only).
        if world_rank == 0:
            with open(f"{cfg.result_dir}/cfg.yml", "w") as f:
                yaml.dump(vars(cfg), f)

        max_steps = cfg.max_steps
        init_step = 0

        # LR schedulers: means decay to 1% of their initial LR over training.
        schedulers = [
            torch.optim.lr_scheduler.ExponentialLR(
                self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps)
            ),
        ]
        if cfg.pose_opt:
            # Pose optimization uses the same exponential decay.
            schedulers.append(
                torch.optim.lr_scheduler.ExponentialLR(
                    self.pose_optimizers[0], gamma=0.01 ** (1.0 / max_steps)
                )
            )
        if cfg.use_bilateral_grid:
            # Bilateral grid: 1k-step linear warmup chained with exponential decay.
            schedulers.append(
                torch.optim.lr_scheduler.ChainedScheduler(
                    [
                        torch.optim.lr_scheduler.LinearLR(
                            self.bil_grid_optimizers[0],
                            start_factor=0.01,
                            total_iters=1000,
                        ),
                        torch.optim.lr_scheduler.ExponentialLR(
                            self.bil_grid_optimizers[0], gamma=0.01 ** (1.0 / max_steps)
                        ),
                    ]
                )
            )

        trainloader = torch.utils.data.DataLoader(
            self.trainset,
            batch_size=cfg.batch_size,
            shuffle=True,
            num_workers=4,
            persistent_workers=True,
            pin_memory=True,
        )
        # Infinite iteration: the iterator is re-created on StopIteration below.
        trainloader_iter = iter(trainloader)

        # Training loop.
        global_tic = time.time()
        pbar = tqdm.tqdm(range(init_step, max_steps))
        for step in pbar:
            if not cfg.disable_viewer:
                while self.viewer.state == "paused":
                    time.sleep(0.01)
                self.viewer.lock.acquire()
                tic = time.time()

            try:
                data = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = iter(trainloader)
                data = next(trainloader_iter)

            camtoworlds = camtoworlds_gt = data["camtoworld"].to(device)  # [B, 4, 4]
            Ks = data["K"].to(device)  # [B, 3, 3]
            pixels = data["image"].to(device) / 255.0  # [B, H, W, 3] in [0, 1]
            num_train_rays_per_step = (
                pixels.shape[0] * pixels.shape[1] * pixels.shape[2]
            )
            image_ids = data["image_id"].to(device)
            masks = data["mask"].to(device) if "mask" in data else None
            if cfg.depth_loss:
                points = data["points"].to(device)  # sparse 2D sample locations
                depths_gt = data["depths"].to(device)  # matching GT depths

            height, width = pixels.shape[1:3]

            # Optionally perturb, then refine, the camera poses.
            if cfg.pose_noise:
                camtoworlds = self.pose_perturb(camtoworlds, image_ids)

            if cfg.pose_opt:
                camtoworlds = self.pose_adjust(camtoworlds, image_ids)

            # The full SH degree is used from step 0 (no progressive schedule).
            sh_degree_to_use = cfg.sh_degree

            # Forward pass; "RGB+ED" additionally renders expected depth.
            renders, alphas, info = self.rasterize_splats(
                camtoworlds=camtoworlds,
                Ks=Ks,
                width=width,
                height=height,
                sh_degree=sh_degree_to_use,
                near_plane=cfg.near_plane,
                far_plane=cfg.far_plane,
                image_ids=image_ids,
                render_mode="RGB+ED" if cfg.depth_loss else "RGB",
                masks=masks,
            )
            if renders.shape[-1] == 4:
                colors, depths = renders[..., 0:3], renders[..., 3:4]
            else:
                colors, depths = renders, None

            # Per-image color correction via the bilateral grid, sampled on a
            # pixel-center UV grid.
            if cfg.use_bilateral_grid:
                grid_y, grid_x = torch.meshgrid(
                    (torch.arange(height, device=self.device) + 0.5) / height,
                    (torch.arange(width, device=self.device) + 0.5) / width,
                    indexing="ij",
                )
                grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)
                colors = slice(self.bil_grids, grid_xy, colors, image_ids)["rgb"]

            # Composite over a random background to discourage transparency.
            if cfg.random_bkgd:
                bkgd = torch.rand(1, 3, device=device)
                colors = colors + bkgd * (1.0 - alphas)

            self.cfg.strategy.step_pre_backward(
                params=self.splats,
                optimizers=self.optimizers,
                state=self.strategy_state,
                step=step,
                info=info,
            )

            # Photometric loss: L1 + SSIM mix.
            l1loss = F.l1_loss(colors, pixels)
            ssimloss = 1.0 - fused_ssim(
                colors.permute(0, 3, 1, 2), pixels.permute(0, 3, 1, 2), padding="valid"
            )
            loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda
            if cfg.depth_loss:
                # Sample the rendered depth map at the sparse supervision
                # points (coords normalized to [-1, 1] for grid_sample).
                points = torch.stack(
                    [
                        points[:, :, 0] / (width - 1) * 2 - 1,
                        points[:, :, 1] / (height - 1) * 2 - 1,
                    ],
                    dim=-1,
                )
                grid = points.unsqueeze(2)  # [B, M, 1, 2]
                depths = F.grid_sample(
                    depths.permute(0, 3, 1, 2), grid, align_corners=True
                )
                depths = depths.squeeze(3).squeeze(1)  # [B, M]
                # Compare in disparity space; guard against zero depth.
                disp = torch.where(depths > 0.0, 1.0 / depths, torch.zeros_like(depths))
                disp_gt = 1.0 / depths_gt
                # NOTE(review): self.scene_scale is never assigned in __init__,
                # so enabling cfg.depth_loss raises AttributeError here
                # (create_splats_with_optimizers uses a local scene_scale=1.0).
                depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale
                loss += depthloss * cfg.depth_lambda
            if cfg.use_bilateral_grid:
                tvloss = 10 * total_variation_loss(self.bil_grids.grids)
                loss += tvloss

            # Regularizers (applied to activated opacities/scales).
            if cfg.opacity_reg > 0.0:
                loss = (
                    loss
                    + cfg.opacity_reg
                    * torch.abs(torch.sigmoid(self.splats["opacities"])).mean()
                )
            if cfg.scale_reg > 0.0:
                loss = (
                    loss
                    + cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean()
                )

            loss.backward()

            desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| "
            if cfg.depth_loss:
                desc += f"depth loss={depthloss.item():.6f}| "
            if cfg.pose_opt and cfg.pose_noise:
                # Monitor the pose error when noise was injected.
                pose_err = F.l1_loss(camtoworlds_gt, camtoworlds)
                desc += f"pose err={pose_err.item():.6f}| "
            pbar.set_description(desc)

            # Tensorboard monitoring (rank 0 only).
            if world_rank == 0 and cfg.tb_every > 0 and step % cfg.tb_every == 0:
                mem = torch.cuda.max_memory_allocated() / 1024**3
                self.writer.add_scalar("train/loss", loss.item(), step)
                self.writer.add_scalar("train/l1loss", l1loss.item(), step)
                self.writer.add_scalar("train/ssimloss", ssimloss.item(), step)
                self.writer.add_scalar("train/num_GS", len(self.splats["means"]), step)
                self.writer.add_scalar("train/mem", mem, step)
                if cfg.depth_loss:
                    self.writer.add_scalar("train/depthloss", depthloss.item(), step)
                if cfg.use_bilateral_grid:
                    self.writer.add_scalar("train/tvloss", tvloss.item(), step)
                if cfg.tb_save_image:
                    canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy()
                    canvas = canvas.reshape(-1, *canvas.shape[2:])
                    self.writer.add_image("train/render", canvas, step)
                self.writer.flush()

            # Save checkpoint + stats (before the strategy mutates the model).
            if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1:
                mem = torch.cuda.max_memory_allocated() / 1024**3
                stats = {
                    "mem": mem,
                    "ellipse_time": time.time() - global_tic,
                    "num_GS": len(self.splats["means"]),
                }
                print("Step: ", step, stats)
                with open(
                    f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json",
                    "w",
                ) as f:
                    json.dump(stats, f)
                data = {"step": step, "splats": self.splats.state_dict()}
                if cfg.pose_opt:
                    if world_size > 1:
                        data["pose_adjust"] = self.pose_adjust.module.state_dict()
                    else:
                        data["pose_adjust"] = self.pose_adjust.state_dict()
                if cfg.app_opt:
                    if world_size > 1:
                        data["app_module"] = self.app_module.module.state_dict()
                    else:
                        data["app_module"] = self.app_module.state_dict()
                torch.save(
                    data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt"
                )
            # Optionally export a .ply snapshot of the current splats.
            if (
                step in [i - 1 for i in cfg.ply_steps] or step == max_steps - 1
            ) and cfg.save_ply:
                if self.cfg.app_opt:
                    # Bake the appearance model into SH DC colors (evaluated
                    # with zero view directions and no embedding).
                    rgb = self.app_module(
                        features=self.splats["features"],
                        embed_ids=None,
                        dirs=torch.zeros_like(self.splats["means"][None, :, :]),
                        sh_degree=sh_degree_to_use,
                    )
                    rgb = rgb + self.splats["colors"]
                    rgb = torch.sigmoid(rgb).squeeze(0).unsqueeze(1)
                    sh0 = rgb_to_sh(rgb)
                    shN = torch.empty([sh0.shape[0], 0, 3], device=sh0.device)
                else:
                    sh0 = self.splats["sh0"]
                    shN = self.splats["shN"]

                means = self.splats["means"]
                # Note: log-scales and sigmoided opacities are handed to
                # export_ply as-is / activated respectively.
                scales = self.splats["scales"]
                quats = self.splats["quats"]
                opacities = self.splats["opacities"]

                export_ply(
                    means=means,
                    scales=scales,
                    rotations=quats,
                    harmonics=torch.cat([sh0, shN], dim=1).permute(0, 2, 1),
                    opacities=opacities.sigmoid(),
                    path=Path(f"{self.ply_dir}/point_cloud_{step}.ply"),
                )

            # Turn dense gradients into sparse tensors before the optimizer step.
            if cfg.sparse_grad:
                assert cfg.packed, "Sparse gradients only work with packed mode."
                gaussian_ids = info["gaussian_ids"]
                for k in self.splats.keys():
                    grad = self.splats[k].grad
                    if grad is None or grad.is_sparse:
                        continue
                    self.splats[k].grad = torch.sparse_coo_tensor(
                        indices=gaussian_ids[None],  # [1, nnz]
                        values=grad[gaussian_ids],  # [nnz, ...]
                        size=self.splats[k].size(),  # [N, ...]
                        is_coalesced=len(Ks) == 1,
                    )

            # Build the visibility mask consumed by SelectiveAdam.
            if cfg.visible_adam:
                gaussian_cnt = self.splats.means.shape[0]
                if cfg.packed:
                    visibility_mask = torch.zeros_like(
                        self.splats["opacities"], dtype=bool
                    )
                    visibility_mask.scatter_(0, info["gaussian_ids"], 1)
                else:
                    visibility_mask = (info["radii"] > 0).all(-1).any(0)

            # Optimize all parameter groups, then step the schedulers.
            for optimizer in self.optimizers.values():
                if cfg.visible_adam:
                    optimizer.step(visibility_mask)
                else:
                    optimizer.step()
                optimizer.zero_grad(set_to_none=True)
            for optimizer in self.pose_optimizers:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
            for optimizer in self.app_optimizers:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
            for optimizer in self.bil_grid_optimizers:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
            for scheduler in schedulers:
                scheduler.step()

            # Densification / relocation bookkeeping after the optimizer step.
            if isinstance(self.cfg.strategy, DefaultStrategy):
                self.cfg.strategy.step_post_backward(
                    params=self.splats,
                    optimizers=self.optimizers,
                    state=self.strategy_state,
                    step=step,
                    info=info,
                    packed=cfg.packed,
                )
            elif isinstance(self.cfg.strategy, MCMCStrategy):
                self.cfg.strategy.step_post_backward(
                    params=self.splats,
                    optimizers=self.optimizers,
                    state=self.strategy_state,
                    step=step,
                    info=info,
                    lr=schedulers[0].get_last_lr()[0],
                )
            else:
                assert_never(self.cfg.strategy)

            # Periodic evaluation on the held-out views.
            if step in [i - 1 for i in cfg.eval_steps]:
                self.eval(step)

            # Optional compression at the same milestones.
            if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]:
                self.run_compression(step=step)

            if not cfg.disable_viewer:
                self.viewer.lock.release()
                num_train_steps_per_sec = 1.0 / (time.time() - tic)
                num_train_rays_per_sec = (
                    num_train_rays_per_step * num_train_steps_per_sec
                )
                # Update the viewer throughput state, then redraw the scene.
                self.viewer.render_tab_state.num_train_rays_per_sec = (
                    num_train_rays_per_sec
                )
                self.viewer.update(step, num_train_rays_per_step)
|
|
    @torch.no_grad()
    def eval(self, step: int, stage: str = "val"):
        """Entry for evaluation.

        Renders every view in the validation set, saves side-by-side and
        individual GT/render images (rank 0 only), and logs PSNR/SSIM/LPIPS
        plus timing stats to JSON and TensorBoard.

        Args:
            step: Current training step; used to tag output files and scalars.
            stage: Label for output filenames / scalar prefixes (e.g. "val",
                "compress").
        """
        print("Running evaluation...")
        cfg = self.cfg
        device = self.device
        world_rank = self.world_rank
        world_size = self.world_size

        valloader = torch.utils.data.DataLoader(
            self.valset, batch_size=1, shuffle=False, num_workers=1
        )
        # Accumulated wall-clock render time over all views (averaged later).
        ellipse_time = 0
        metrics = defaultdict(list)
        for i, data in enumerate(valloader):
            camtoworlds = data["camtoworld"].to(device)
            Ks = data["K"].to(device)
            # Images come in as uint8 [0, 255]; normalize to [0, 1] floats.
            pixels = data["image"].to(device) / 255.0
            masks = data["mask"].to(device) if "mask" in data else None
            # pixels is [1, H, W, 3] (channels-last) — TODO confirm with dataset.
            height, width = pixels.shape[1:3]

            # Synchronize before/after so the timer measures actual GPU work,
            # not just async kernel launches.
            torch.cuda.synchronize()
            tic = time.time()
            # "RGB+ED" packs color in channels 0:3 and expected depth in 3.
            render_colors, _, _ = self.rasterize_splats(
                camtoworlds=camtoworlds,
                Ks=Ks,
                width=width,
                height=height,
                sh_degree=cfg.sh_degree,
                near_plane=cfg.near_plane,
                far_plane=cfg.far_plane,
                render_mode="RGB+ED",
                masks=masks,
            )
            torch.cuda.synchronize()
            ellipse_time += time.time() - tic

            colors = render_colors[..., :3]
            depths = render_colors[..., 3]

            colors = torch.clamp(colors, 0.0, 1.0)
            canvas_list = [pixels, colors]

            # Only rank 0 writes images to disk to avoid duplicate files.
            if world_rank == 0:
                # Side-by-side GT | render comparison (concatenated on width).
                canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy()
                canvas = (canvas * 255).astype(np.uint8)
                imageio.imwrite(
                    f"{self.render_dir}/{stage}_step{step}_{i:04d}.png",
                    canvas,
                )
                # save_image expects channels-first [N, C, H, W].
                torchvision.utils.save_image(pixels.permute(0, 3, 1, 2), f"{self.render_dir}/gt_rgb_{stage}_step{step}_{i:04d}.png")
                torchvision.utils.save_image(colors.permute(0, 3, 1, 2), f"{self.render_dir}/render_rgb_{stage}_step{step}_{i:04d}.png")

            # torchmetrics expects channels-first [N, C, H, W].
            pixels_p = pixels.permute(0, 3, 1, 2)
            colors_p = colors.permute(0, 3, 1, 2)

            metrics["psnr"].append(self.psnr(colors_p, pixels_p))
            metrics["ssim"].append(self.ssim(colors_p, pixels_p))
            metrics["lpips"].append(self.lpips(colors_p, pixels_p))
            if cfg.use_bilateral_grid:
                # Color-corrected PSNR: compensates for global color shifts
                # before comparing (see lib_bilagrid.color_correct).
                cc_colors = color_correct(colors, pixels)
                cc_colors_p = cc_colors.permute(0, 3, 1, 2)
                metrics["cc_psnr"].append(self.psnr(cc_colors_p, pixels_p))

        if world_rank == 0:
            # Average per-image render time.
            ellipse_time /= len(valloader)

            stats = {k: torch.stack(v).mean().item() for k, v in metrics.items()}
            stats.update(
                {
                    "ellipse_time": ellipse_time,
                    "num_GS": len(self.splats["means"]),
                }
            )
            print(
                f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} "
                f"Time: {stats['ellipse_time']:.3f}s/image "
                f"Number of GS: {stats['num_GS']}"
            )
            # Persist stats as JSON for offline aggregation.
            with open(f"{self.stats_dir}/{stage}_step{step:04d}.json", "w") as f:
                json.dump(stats, f)
            # Mirror stats to TensorBoard.
            for k, v in stats.items():
                self.writer.add_scalar(f"{stage}/{k}", v, step)
            self.writer.flush()
|
|
| @torch.no_grad() |
| def render_traj(self, step: int): |
| """Entry for trajectory rendering.""" |
| if self.cfg.disable_video: |
| return |
| print("Running trajectory rendering...") |
| cfg = self.cfg |
| device = self.device |
|
|
| camtoworlds_all = self.parser.camtoworlds[5:-5] |
| if cfg.render_traj_path == "interp": |
| camtoworlds_all = generate_interpolated_path( |
| camtoworlds_all, 1 |
| ) |
| elif cfg.render_traj_path == "ellipse": |
| height = camtoworlds_all[:, 2, 3].mean() |
| camtoworlds_all = generate_ellipse_path_z( |
| camtoworlds_all, height=height |
| ) |
| elif cfg.render_traj_path == "spiral": |
| camtoworlds_all = generate_spiral_path( |
| camtoworlds_all, |
| bounds=self.parser.bounds * self.scene_scale, |
| spiral_scale_r=self.parser.extconf["spiral_radius_scale"], |
| ) |
| else: |
| raise ValueError( |
| f"Render trajectory type not supported: {cfg.render_traj_path}" |
| ) |
|
|
| camtoworlds_all = np.concatenate( |
| [ |
| camtoworlds_all, |
| np.repeat( |
| np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds_all), axis=0 |
| ), |
| ], |
| axis=1, |
| ) |
|
|
| camtoworlds_all = torch.from_numpy(camtoworlds_all).float().to(device) |
| K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device) |
| width, height = list(self.parser.imsize_dict.values())[0] |
|
|
| |
| video_dir = f"{cfg.result_dir}/videos" |
| os.makedirs(video_dir, exist_ok=True) |
| writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30) |
| for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"): |
| camtoworlds = camtoworlds_all[i : i + 1] |
| Ks = K[None] |
|
|
| renders, _, _ = self.rasterize_splats( |
| camtoworlds=camtoworlds, |
| Ks=Ks, |
| width=width, |
| height=height, |
| sh_degree=cfg.sh_degree, |
| near_plane=cfg.near_plane, |
| far_plane=cfg.far_plane, |
| render_mode="RGB+ED", |
| ) |
| colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) |
| depths = renders[..., 3:4] |
| depths = (depths - depths.min()) / (depths.max() - depths.min()) |
| canvas_list = [colors, depths.repeat(1, 1, 1, 3)] |
|
|
| |
| canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() |
| canvas = (canvas * 255).astype(np.uint8) |
| writer.append_data(canvas) |
| writer.close() |
| print(f"Video saved to {video_dir}/traj_{step}.mp4") |
|
|
| @torch.no_grad() |
| def run_compression(self, step: int): |
| """Entry for running compression.""" |
| print("Running compression...") |
| world_rank = self.world_rank |
|
|
| compress_dir = f"{cfg.result_dir}/compression/rank{world_rank}" |
| os.makedirs(compress_dir, exist_ok=True) |
|
|
| self.compression_method.compress(compress_dir, self.splats) |
|
|
| |
| splats_c = self.compression_method.decompress(compress_dir) |
| for k in splats_c.keys(): |
| self.splats[k].data = splats_c[k].to(self.device) |
| self.eval(step=step, stage="compress") |
|
|
| @torch.no_grad() |
| def _viewer_render_fn( |
| self, camera_state: CameraState, render_tab_state: RenderTabState |
| ): |
| assert isinstance(render_tab_state, GsplatRenderTabState) |
| if render_tab_state.preview_render: |
| width = render_tab_state.render_width |
| height = render_tab_state.render_height |
| else: |
| width = render_tab_state.viewer_width |
| height = render_tab_state.viewer_height |
| c2w = camera_state.c2w |
| K = camera_state.get_K((width, height)) |
| c2w = torch.from_numpy(c2w).float().to(self.device) |
| K = torch.from_numpy(K).float().to(self.device) |
|
|
| RENDER_MODE_MAP = { |
| "rgb": "RGB", |
| "depth(accumulated)": "D", |
| "depth(expected)": "ED", |
| "alpha": "RGB", |
| } |
|
|
| render_colors, render_alphas, info = self.rasterize_splats( |
| camtoworlds=c2w[None], |
| Ks=K[None], |
| width=width, |
| height=height, |
| sh_degree=min(render_tab_state.max_sh_degree, self.cfg.sh_degree), |
| near_plane=render_tab_state.near_plane, |
| far_plane=render_tab_state.far_plane, |
| radius_clip=render_tab_state.radius_clip, |
| |
| eps2d=render_tab_state.eps2d, |
| backgrounds=torch.tensor([render_tab_state.backgrounds], device=self.device) |
| / 255.0, |
| render_mode=RENDER_MODE_MAP[render_tab_state.render_mode], |
| rasterize_mode=render_tab_state.rasterize_mode, |
| camera_model=render_tab_state.camera_model, |
| ) |
| render_tab_state.total_gs_count = len(self.splats["means"]) |
| render_tab_state.rendered_gs_count = (info["radii"] > 0).all(-1).sum().item() |
|
|
| if render_tab_state.render_mode == "rgb": |
| |
| render_colors = render_colors[0, ..., 0:3].clamp(0, 1) |
| renders = render_colors.cpu().numpy() |
| elif render_tab_state.render_mode in ["depth(accumulated)", "depth(expected)"]: |
| |
| depth = render_colors[0, ..., 0:1] |
| if render_tab_state.normalize_nearfar: |
| near_plane = render_tab_state.near_plane |
| far_plane = render_tab_state.far_plane |
| else: |
| near_plane = depth.min() |
| far_plane = depth.max() |
| depth_norm = (depth - near_plane) / (far_plane - near_plane + 1e-10) |
| depth_norm = torch.clip(depth_norm, 0, 1) |
| if render_tab_state.inverse: |
| depth_norm = 1 - depth_norm |
| renders = ( |
| apply_float_colormap(depth_norm, render_tab_state.colormap) |
| .cpu() |
| .numpy() |
| ) |
| elif render_tab_state.render_mode == "alpha": |
| alpha = render_alphas[0, ..., 0:1] |
| if render_tab_state.inverse: |
| alpha = 1 - alpha |
| renders = ( |
| apply_float_colormap(alpha, render_tab_state.colormap).cpu().numpy() |
| ) |
| return renders |
|
|
|
|
def main(local_rank: int, world_rank, world_size: int, cfg: Config):
    """Per-process entry point: evaluate a checkpoint or train from scratch.

    Args:
        local_rank: GPU index on this node.
        world_rank: Global rank of this process.
        world_size: Total number of processes.
        cfg: Training configuration (may be mutated to disable the viewer).
    """
    # The interactive viewer cannot run under distributed training.
    if world_size > 1 and not cfg.disable_viewer:
        cfg.disable_viewer = True
        if world_rank == 0:
            print("Viewer is disabled in distributed training.")

    runner = Runner(local_rank, world_rank, world_size, cfg)

    if cfg.ckpt is None:
        # Fresh training run, followed by a final evaluation.
        runner.train()
        runner.eval(step=runner.cfg.max_steps)
        print("Training complete.")
    else:
        # Evaluation-only path: load (possibly sharded) checkpoints and
        # concatenate each parameter across shards.
        ckpts = [
            torch.load(file, map_location=runner.device, weights_only=True)
            for file in cfg.ckpt
        ]
        for k in runner.splats.keys():
            runner.splats[k].data = torch.cat([ckpt["splats"][k] for ckpt in ckpts])
        step = ckpts[0]["step"]
        runner.eval(step=step)
        if cfg.compression is not None:
            runner.run_compression(step=step)
| |
| |
| |
| |
|
|
|
|
if __name__ == "__main__":
    """
    Usage:

    ```bash
    # Single GPU training
    CUDA_VISIBLE_DEVICES=9 python -m examples.simple_trainer default

    # Distributed training on 4 GPUs: Effectively 4x batch size so run 4x less steps.
    CUDA_VISIBLE_DEVICES=0,1,2,3 python simple_trainer.py default --steps_scaler 0.25

    """

    # Named config presets selectable from the command line via tyro.
    configs = {
        "default": (
            "Gaussian splatting training using densification heuristics from the original paper.",
            Config(
                strategy=DefaultStrategy(verbose=True),
            ),
        ),
        "mcmc": (
            "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.",
            Config(
                init_opa=0.5,
                init_scale=0.1,
                opacity_reg=0.01,
                scale_reg=0.01,
                strategy=MCMCStrategy(verbose=True),
            ),
        ),
    }
    cfg = tyro.extras.overridable_config_cli(configs)
    # Scale step-based hyperparameters (e.g. for multi-GPU runs).
    cfg.adjust_steps(cfg.steps_scaler)

    if cfg.compression == "png":
        # PNG compression needs optional third-party packages; fail early
        # with install instructions rather than with a late ImportError.
        # Fix: the original used a bare `except:`, which also swallows
        # KeyboardInterrupt/SystemExit; catch ImportError only and chain it.
        try:
            import plas  # noqa: F401
            import torchpq  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "To use PNG compression, you need to install "
                "torchpq (instruction at https://github.com/DeMoriarty/TorchPQ?tab=readme-ov-file#install) "
                "and plas (via 'pip install git+https://github.com/fraunhoferhhi/PLAS.git') "
            ) from err

    cli(main, cfg, verbose=True)
|
|