Spaces:
Runtime error
Runtime error
| import json | |
| import math | |
| import os | |
| import time | |
| from collections import defaultdict | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple, Union | |
| import imageio | |
| import matplotlib | |
| import torchvision | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import tqdm | |
| import tyro | |
| import viser | |
| import yaml | |
| import torchvision | |
| import sys | |
| from plyfile import PlyData, PlyElement | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) | |
| from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri | |
| from src.model.types import Gaussians | |
| from src.post_opt.datasets.colmap import Dataset, Parser | |
| from src.post_opt.datasets.traj import ( | |
| generate_ellipse_path_z, | |
| generate_interpolated_path, | |
| generate_spiral_path, | |
| ) | |
| from fused_ssim import fused_ssim | |
| from src.utils.image import process_image | |
| from src.post_opt.exporter import export_splats | |
| from src.post_opt.lib_bilagrid import BilateralGrid, color_correct, slice, total_variation_loss | |
| from torch import Tensor | |
| from torch.nn.parallel import DistributedDataParallel as DDP | |
| from torch.utils.tensorboard import SummaryWriter | |
| from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure | |
| from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity | |
| from typing_extensions import Literal, assert_never | |
| from src.post_opt.utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed | |
| # from gsplat import export_splats | |
| from gsplat.compression import PngCompression | |
| from gsplat.distributed import cli | |
| # from gsplat.optimizers import SelectiveAdam | |
| # from gsplat.rendering import rasterization | |
| from gsplat import rasterization | |
| from gsplat.strategy import DefaultStrategy, MCMCStrategy | |
| from src.post_opt.gsplat_viewer import GsplatViewer, GsplatRenderTabState | |
| from nerfview import CameraState, RenderTabState, apply_float_colormap | |
| import torch | |
| from einops import rearrange | |
| from jaxtyping import Float | |
| from torch import Tensor | |
| from scipy.spatial.transform import Rotation as R | |
| from src.model.model.anysplat import AnySplat | |
| # pytorch3d/pytorch3d/transforms/rotation_conversions.py at main · facebookresearch/pytorch3d | |
def quaternion_to_matrix(
    quaternions: Float[Tensor, "*batch 4"],
    eps: float = 1e-8,
) -> Float[Tensor, "*batch 3 3"]:
    """Convert scalar-last quaternions into 3x3 rotation matrices.

    Args:
        quaternions: Tensor whose last axis holds (x, y, z, w), i.e. the
            real part comes last (the ordering scipy uses).
        eps: Added to the squared norm so a zero quaternion does not divide
            by zero.

    Returns:
        Tensor with the leading axes of ``quaternions`` followed by a 3x3
        matrix.
    """
    # Component order differs from the pytorch3d original: real part last.
    x, y, z, w = torch.unbind(quaternions, dim=-1)

    # 2 / |q|^2, which makes the formula valid for non-unit quaternions.
    scale = 2 / ((quaternions * quaternions).sum(dim=-1) + eps)

    row0 = (
        1 - scale * (y * y + z * z),
        scale * (x * y - z * w),
        scale * (x * z + y * w),
    )
    row1 = (
        scale * (x * y + z * w),
        1 - scale * (x * x + z * z),
        scale * (y * z - x * w),
    )
    row2 = (
        scale * (x * z - y * w),
        scale * (y * z + x * w),
        1 - scale * (x * x + y * y),
    )

    flat = torch.stack(row0 + row1 + row2, -1)
    return rearrange(flat, "... (i j) -> ... i j", i=3, j=3)
def construct_list_of_attributes(num_rest: int) -> list[str]:
    """Return the per-vertex PLY property names, in storage order.

    The order is: position, normal, three DC colour coefficients,
    ``num_rest`` higher-order coefficients, opacity, three scales and four
    rotation components.

    Args:
        num_rest: Number of ``f_rest_*`` properties to include.
    """
    return [
        "x",
        "y",
        "z",
        "nx",
        "ny",
        "nz",
        *(f"f_dc_{idx}" for idx in range(3)),
        *(f"f_rest_{idx}" for idx in range(num_rest)),
        "opacity",
        *(f"scale_{idx}" for idx in range(3)),
        *(f"rot_{idx}" for idx in range(4)),
    ]
def export_ply(
    means: Float[Tensor, "gaussian 3"],
    scales: Float[Tensor, "gaussian 3"],
    rotations: Float[Tensor, "gaussian 4"],
    harmonics: Float[Tensor, "gaussian 3 d_sh"],
    opacities: Float[Tensor, " gaussian"],
    path: Path,
    shift_and_scale: bool = False,
    save_sh_dc_only: bool = True,
):
    """Write a set of Gaussians to ``path`` as a PLY "vertex" element.

    Args:
        means: Gaussian centres.
        scales: Per-axis Gaussian scales.
        rotations: Quaternions in scipy's scalar-last (x, y, z, w) order;
            they are written scalar-first (w, x, y, z).
        harmonics: Spherical-harmonic coefficients; index 0 of the last
            axis is the DC band.
        opacities: One opacity per Gaussian.
        path: Output file; missing parent directories are created.
        shift_and_scale: If true, centre the scene on the median mean and
            divide means and scales by the largest per-axis 95th-percentile
            absolute coordinate.
        save_sh_dc_only: If true, only the DC band is stored (no
            ``f_rest_*`` properties), which keeps the file small.
    """
    if shift_and_scale:
        # Put the median Gaussian at the origin ...
        centered = means - means.median(dim=0).values
        # ... and shrink so that most Gaussians fall inside [-1, 1].
        extent = centered.abs().quantile(0.95, dim=0).max()
        means = centered / extent
        scales = scales / extent

    # Round-trip quaternion -> matrix -> quaternion through scipy, then
    # reorder scipy's (x, y, z, w) into (w, x, y, z) for the file.
    rotation_matrices = R.from_quat(rotations.detach().cpu().numpy()).as_matrix()
    quats_xyzw = R.from_matrix(rotation_matrices).as_quat()
    qx, qy, qz, qw = rearrange(quats_xyzw, "g xyzw -> xyzw g")
    quats_wxyz = np.stack((qw, qx, qy, qz), axis=-1)

    # The model uses a high SH degree, so the non-DC bands dominate storage;
    # they are only written when explicitly requested.
    sh_dc = harmonics[..., 0]
    sh_rest = harmonics[..., 1:].flatten(start_dim=1)

    num_rest = 0 if save_sh_dc_only else sh_rest.shape[1]
    record_dtype = [(name, "f4") for name in construct_list_of_attributes(num_rest)]
    vertices = np.empty(means.shape[0], dtype=record_dtype)

    # Column groups in the same order as construct_list_of_attributes.
    columns = [
        means.detach().cpu().numpy(),
        torch.zeros_like(means).detach().cpu().numpy(),  # normals (all zero)
        sh_dc.detach().cpu().contiguous().numpy(),
    ]
    if not save_sh_dc_only:
        columns.append(sh_rest.detach().cpu().contiguous().numpy())
    columns.extend(
        [
            opacities[..., None].detach().cpu().numpy(),
            scales.detach().cpu().numpy(),
            quats_wxyz,
        ]
    )

    table = np.concatenate(columns, axis=1)
    vertices[:] = list(map(tuple, table))

    path.parent.mkdir(exist_ok=True, parents=True)
    PlyData([PlyElement.describe(vertices, "vertex")]).write(path)
def colorize_depth_maps(depth_map, min_depth=0.0, max_depth=1.0, cmap="Spectral", valid_mask=None):
    """Colorize depth maps with a matplotlib colormap.

    The depth values are min-max normalised over the whole input before the
    colormap is applied.

    Args:
        depth_map: ``torch.Tensor`` or ``np.ndarray`` with at least two
            dimensions; after squeezing it is treated as [H, W] or
            [B, H, W].
        min_depth: Currently ignored (normalisation uses the data range);
            kept so existing callers keep working.
        max_depth: Currently ignored, see ``min_depth``.
        cmap: Name of a colormap registered in ``matplotlib.colormaps``.
        valid_mask: Optional mask (tensor or array) shaped like the depth
            map; pixels where it is false are set to 0 in the output.

    Returns:
        RGB values in [0, 1] laid out as [B, 3, H, W]: a float32
        ``torch.Tensor`` when ``depth_map`` is a tensor, otherwise a numpy
        array.

    Raises:
        TypeError: If ``depth_map`` is neither a tensor nor a numpy array.
    """
    assert len(depth_map.shape) >= 2, "Invalid dimension"

    if isinstance(depth_map, torch.Tensor):
        # .cpu() is required: .numpy() cannot convert a CUDA tensor.
        depth = depth_map.detach().cpu().clone().squeeze().numpy()
    elif isinstance(depth_map, np.ndarray):
        depth = depth_map.copy().squeeze()
    else:
        raise TypeError(
            f"depth_map must be a torch.Tensor or np.ndarray, got {type(depth_map).__name__}"
        )

    # reshape to [ (B,) H, W ]
    if depth.ndim < 3:
        depth = depth[np.newaxis, :, :]

    # colorize
    colormap = matplotlib.colormaps[cmap]
    depth_lo = depth.min()
    depth_span = depth.max() - depth_lo
    if depth_span > 0:
        depth = ((depth - depth_lo) / depth_span).clip(0, 1)
    else:
        # Constant (or all-NaN) input: avoid 0/0 and use the low end of the
        # colormap everywhere. Float dtype so the colormap does not treat
        # the values as integer lookup indices.
        depth = np.zeros(depth.shape, dtype=np.float64)

    img_colored_np = colormap(depth, bytes=False)[:, :, :, 0:3]  # value from 0 to 1
    img_colored_np = np.rollaxis(img_colored_np, 3, 1)  # [B, H, W, 3] -> [B, 3, H, W]

    if valid_mask is not None:
        # Decide on the mask's own type, not on the depth map's.
        if isinstance(valid_mask, torch.Tensor):
            valid_mask = valid_mask.detach().cpu().numpy()
        # Force bool so that ~mask is a logical not rather than a bitwise one.
        valid_mask = np.asarray(valid_mask).astype(bool).squeeze()  # [H, W] or [B, H, W]
        if valid_mask.ndim < 3:
            valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
        else:
            valid_mask = valid_mask[:, np.newaxis, :, :]
        valid_mask = np.repeat(valid_mask, 3, axis=1)
        img_colored_np[~valid_mask] = 0

    if isinstance(depth_map, torch.Tensor):
        return torch.from_numpy(img_colored_np).float()
    return img_colored_np
def build_covariance(
    scale: Float[Tensor, "*#batch 3"],
    rotation_xyzw: Float[Tensor, "*#batch 4"],
) -> Float[Tensor, "*batch 3 3"]:
    """Build 3x3 covariance matrices from per-axis scales and quaternions.

    Computes ``R @ S @ S^T @ R^T`` where ``S`` is the diagonal matrix of
    ``scale`` and ``R`` is the rotation matrix of ``rotation_xyzw``.

    Args:
        scale: Per-axis scales (last axis of size 3).
        rotation_xyzw: Scalar-last quaternions (last axis of size 4), as
            expected by ``quaternion_to_matrix``.

    Returns:
        Covariance matrices with a trailing 3x3 shape.
    """
    scale = scale.diag_embed()
    rotation = quaternion_to_matrix(rotation_xyzw)
    # The previous body returned only `rotation`, leaving `scale` unused;
    # the full product below is what the name and return type promise.
    return (
        rotation
        @ scale
        @ scale.transpose(-1, -2)
        @ rotation.transpose(-1, -2)
    )
# The decorator is required: several fields below use
# field(default_factory=...), which only takes effect on a dataclass.
@dataclass
class Config:
    """Settings for the Gaussian-splat post-optimization run.

    Read by ``Runner`` and ``create_splats_with_optimizers``. Step-related
    fields can be rescaled together with ``adjust_steps``.
    """

    # Disable viewer
    disable_viewer: bool = True
    # Path to the .pt files. If provided, training is skipped and only evaluation runs.
    ckpt: Optional[List[str]] = None
    # Name of compression strategy to use
    compression: Optional[Literal["png"]] = None
    # Render trajectory path
    render_traj_path: str = "interp"

    # Directory containing the input images
    data_dir: str = "data/360_v2/garden"
    # Downsample factor for the dataset
    data_factor: int = 4
    # Directory to save results
    result_dir: str = "results/garden"
    # Every N images there is a test image
    test_every: int = 8
    # Random crop size for training (experimental)
    patch_size: Optional[int] = None
    # A global scaler that applies to the scene size related parameters
    global_scale: float = 1.0
    # Normalize the world space
    normalize_world_space: bool = True
    # Camera model
    camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole"

    # Port for the viewer server
    port: int = 8080

    # Batch size for training. Learning rates are scaled automatically
    batch_size: int = 1
    # A global factor to scale the number of training steps
    steps_scaler: float = 1.0

    # Number of training steps
    max_steps: int = 3_000
    # Steps at which to evaluate the model
    eval_steps: List[int] = field(default_factory=lambda: [1, 1_000, 3_000, 7_000, 10_000])
    # Steps at which to save the model
    save_steps: List[int] = field(default_factory=lambda: [1, 1_000, 3_000, 7_000, 10_000])
    # Whether to save ply file (storage size can be large)
    save_ply: bool = False
    # Steps at which to save the model as ply
    ply_steps: List[int] = field(default_factory=lambda: [1, 1_000, 3_000, 7_000, 10_000])
    # Whether to disable video generation during training and evaluation
    disable_video: bool = False

    # Initialization strategy
    init_type: str = "sfm"
    # Initial number of GSs. Ignored if using sfm
    init_num_pts: int = 100_000
    # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm
    init_extent: float = 3.0
    # Degree of spherical harmonics
    sh_degree: int = 4
    # Turn on another SH degree every this many steps
    sh_degree_interval: int = 1000
    # Initial opacity of GS
    init_opa: float = 0.1
    # Initial scale of GS
    init_scale: float = 1.0
    # Weight for SSIM loss
    ssim_lambda: float = 0.2

    # Near plane clipping distance
    near_plane: float = 1e-10
    # Far plane clipping distance
    far_plane: float = 1e10

    # Strategy for GS densification
    strategy: Union[DefaultStrategy, MCMCStrategy] = field(
        default_factory=DefaultStrategy
    )
    # Use packed mode for rasterization; this leads to less memory usage but is slightly slower.
    packed: bool = False
    # Use sparse gradients for optimization. (experimental)
    sparse_grad: bool = False
    # Use visible adam from Taming 3DGS. (experimental)
    visible_adam: bool = False
    # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
    antialiased: bool = False

    # Use random background for training to discourage transparency
    random_bkgd: bool = False

    # Opacity regularization
    opacity_reg: float = 0.0
    # Scale regularization
    scale_reg: float = 0.0

    # Enable camera optimization.
    pose_opt: bool = True
    # Learning rate for camera optimization
    pose_opt_lr: float = 1e-5
    # Regularization for camera optimization as weight decay
    pose_opt_reg: float = 1e-6
    # Add noise to camera extrinsics. This is only to test the camera pose optimization.
    pose_noise: float = 0.0

    # Enable appearance optimization. (experimental)
    app_opt: bool = False
    # Appearance embedding dimension
    app_embed_dim: int = 16
    # Learning rate for appearance optimization
    app_opt_lr: float = 1e-3
    # Regularization for appearance optimization as weight decay
    app_opt_reg: float = 1e-6

    # Enable bilateral grid. (experimental)
    use_bilateral_grid: bool = False
    # Shape of the bilateral grid (X, Y, W)
    bilateral_grid_shape: Tuple[int, int, int] = (16, 16, 8)

    # Enable depth loss. (experimental)
    depth_loss: bool = False
    # Weight for depth loss
    depth_lambda: float = 1e-2

    # Dump information to tensorboard every this many steps
    tb_every: int = 100
    # Save training images to tensorboard
    tb_save_image: bool = False

    # Backbone used for the LPIPS metric
    lpips_net: Literal["vgg", "alex"] = "vgg"

    # Base learning rates of the splat parameter groups
    lr_means: float = 1.6e-4
    lr_scales: float = 5e-3
    lr_quats: float = 1e-3
    lr_opacities: float = 5e-2
    lr_sh: float = 2.5e-3

    def adjust_steps(self, factor: float):
        """Rescale every step-count setting by ``factor`` (in place).

        Scales the eval/save/ply step lists, ``max_steps`` and
        ``sh_degree_interval``. For ``MCMCStrategy`` the refinement
        schedule is scaled too; for ``DefaultStrategy`` it is overwritten
        with fixed values instead.
        """
        self.eval_steps = [int(i * factor) for i in self.eval_steps]
        self.save_steps = [int(i * factor) for i in self.save_steps]
        self.ply_steps = [int(i * factor) for i in self.ply_steps]
        self.max_steps = int(self.max_steps * factor)
        self.sh_degree_interval = int(self.sh_degree_interval * factor)

        strategy = self.strategy
        if isinstance(strategy, DefaultStrategy):
            # Deliberately NOT scaled by `factor` (the scaled version was
            # commented out upstream). These fixed values lie beyond the
            # default max_steps of 3_000 -- presumably to switch
            # densification/opacity reset off; confirm before changing.
            strategy.refine_start_iter = 30000
            strategy.refine_stop_iter = 0
            strategy.reset_every = 30000
            strategy.refine_every = 30000
        elif isinstance(strategy, MCMCStrategy):
            strategy.refine_start_iter = int(strategy.refine_start_iter * factor)
            strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
            strategy.refine_every = int(strategy.refine_every * factor)
        else:
            assert_never(strategy)
def create_splats_with_optimizers(
    gaussians: Gaussians,
    init_num_pts: int = 100_000,
    init_extent: float = 3.0,
    init_opacity: float = 0.1,
    init_scale: float = 1.0,
    sh_degree: int = 3,
    sparse_grad: bool = False,
    visible_adam: bool = False,
    batch_size: int = 1,
    feature_dim: Optional[int] = None,
    device: str = "cuda",
    world_rank: int = 0,
    world_size: int = 1,
    cfg: Config = None,
) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]:
    """Turn predicted Gaussians into trainable parameters plus optimizers.

    The first batch element of ``gaussians`` is converted to the
    optimisation parameterisation (log scales, logit opacities), Gaussians
    whose opacity is not above 0.01 are dropped, and one optimizer is
    created per parameter group.

    Args:
        gaussians: Object exposing ``means``, ``scales``, ``rotations``,
            ``opacities`` and ``harmonics`` with a leading batch axis.
            Harmonics are permuted from [N, 3, K] to [N, K, 3].
        sparse_grad: Use ``torch.optim.SparseAdam``.
        visible_adam: Use gsplat's ``SelectiveAdam`` (only consulted when
            ``sparse_grad`` is false).
        batch_size: Per-process batch size; with ``world_size`` it scales
            learning rate, eps and betas.
        device: Device the parameters are moved to.
        world_size: Number of processes.
        cfg: Supplies the ``lr_*`` learning rates; a default ``Config()``
            is used when omitted.
        init_num_pts, init_extent, init_opacity, init_scale, sh_degree,
        feature_dim, world_rank: Accepted for interface compatibility but
            not used by this implementation.

    Returns:
        ``(splats, optimizers)``: a ``ParameterDict`` with keys ``means``,
        ``scales``, ``quats``, ``opacities``, ``sh0`` and ``shN``, and a
        dict holding one optimizer per key.
    """
    if cfg is None:
        # The signature advertises cfg as optional, so fall back to the
        # default learning rates instead of failing on `None.lr_means`.
        cfg = Config()

    points = gaussians.means[0].detach().float()
    scales = torch.log(gaussians.scales[0].detach().float())
    quats = gaussians.rotations[0].detach().float()
    opacities = torch.logit(gaussians.opacities[0].detach().float())
    harmonics = gaussians.harmonics[0].detach().float().permute(0, 2, 1).contiguous()

    scene_scale = 1.0

    # Keep only Gaussians that are not (almost) fully transparent.
    masks = opacities.sigmoid() > 0.01
    harmonics = harmonics[masks]

    params = [
        # name, value, lr
        ("means", torch.nn.Parameter(points[masks]), cfg.lr_means * scene_scale),
        ("scales", torch.nn.Parameter(scales[masks]), cfg.lr_scales),
        ("quats", torch.nn.Parameter(quats[masks]), cfg.lr_quats),
        ("opacities", torch.nn.Parameter(opacities[masks]), cfg.lr_opacities),
        # DC band and higher-order bands are trained at different rates.
        ("sh0", torch.nn.Parameter(harmonics[:, :1, :]), cfg.lr_sh),
        ("shN", torch.nn.Parameter(harmonics[:, 1:, :]), cfg.lr_sh / 20),
    ]

    splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device)

    # Scale learning rate based on batch size, reference:
    # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/
    # Note that this would not make the training exactly equivalent, see
    # https://arxiv.org/pdf/2402.18824v1
    BS = batch_size * world_size

    if sparse_grad:
        optimizer_class = torch.optim.SparseAdam
    elif visible_adam:
        # Imported here because the module-level import of SelectiveAdam is
        # commented out; without this the branch raised NameError.
        from gsplat.optimizers import SelectiveAdam

        optimizer_class = SelectiveAdam
    else:
        optimizer_class = torch.optim.Adam

    optimizers = {
        name: optimizer_class(
            [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}],
            eps=1e-15 / math.sqrt(BS),
            # TODO: check betas logic; when BS reaches 10, betas[0] becomes zero.
            betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)),
        )
        for name, _, lr in params
    }
    return splats, optimizers
| class Runner: | |
| """Engine for training and testing.""" | |
def __init__(
    self, local_rank: int, world_rank, world_size: int, cfg: Config
) -> None:
    """Prepare everything needed for post-optimization.

    Steps, in order: create the output directories and the tensorboard
    writer; load the pretrained AnySplat model and run it on the images in
    ``cfg.data_dir`` to obtain initial Gaussians and camera poses; build
    the train/val datasets; create the trainable splats with their
    optimizers; set up the densification strategy, optional compression,
    pose / appearance / bilateral-grid modules, metrics, and the viewer.

    Args:
        local_rank: Index of the CUDA device used by this process; also
            offsets the random seed.
        world_rank: Global rank of this process.
        world_size: Number of processes.
        cfg: Run configuration.

    Raises:
        ValueError: If the image folder does not yield at least one
            context and one target image, or if ``cfg.compression`` /
            ``cfg.lpips_net`` has an unknown value.
    """
    set_random_seed(42 + local_rank)

    self.cfg = cfg
    self.world_rank = world_rank
    self.local_rank = local_rank
    self.world_size = world_size
    self.device = f"cuda:{local_rank}"
    # train() multiplies the depth loss by self.scene_scale, which was
    # never assigned. 1.0 matches the scene scale used for the splat
    # learning rates and for the strategy state below.
    self.scene_scale = 1.0

    # Where to dump results.
    os.makedirs(cfg.result_dir, exist_ok=True)

    # Setup output directories.
    self.ckpt_dir = f"{cfg.result_dir}/ckpts"
    os.makedirs(self.ckpt_dir, exist_ok=True)
    self.stats_dir = f"{cfg.result_dir}/stats"
    os.makedirs(self.stats_dir, exist_ok=True)
    self.render_dir = f"{cfg.result_dir}/renders"
    os.makedirs(self.render_dir, exist_ok=True)
    self.ply_dir = f"{cfg.result_dir}/ply"
    os.makedirs(self.ply_dir, exist_ok=True)

    # Tensorboard
    self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb")

    # First get the initial 3DGS and camera poses from the frozen model.
    model = AnySplat.from_pretrained("lhjiang/anysplat")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    image_folder = cfg.data_dir
    image_names = sorted(
        os.path.join(image_folder, f)
        for f in os.listdir(image_folder)
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    )
    images = [process_image(img_path) for img_path in image_names]

    # Every `test_every`-th image (starting with the first) is a target
    # (validation) view; all others are context (training) views.
    ctx_indices = [idx for idx in range(len(image_names)) if idx % cfg.test_every != 0]
    tgt_indices = [idx for idx in range(len(image_names)) if idx % cfg.test_every == 0]
    if not ctx_indices or not tgt_indices:
        raise ValueError(
            f"Need at least one context and one target image in {image_folder!r} "
            f"(found {len(image_names)} images with test_every={cfg.test_every})"
        )

    ctx_images = torch.stack([images[i] for i in ctx_indices], dim=0).unsqueeze(0).to(device)
    tgt_images = torch.stack([images[i] for i in tgt_indices], dim=0).unsqueeze(0).to(device)
    # Affine remap x -> (x + 1) / 2; presumably process_image returns values
    # in [-1, 1] -- confirm against src.utils.image.
    ctx_images = (ctx_images + 1) * 0.5
    tgt_images = (tgt_images + 1) * 0.5
    b, _, _, h, w = tgt_images.shape

    # run inference
    encoder_output = model.encoder(
        ctx_images,
        global_step=0,
        visualization_dump={},
    )
    gaussians, pred_context_pose = encoder_output.gaussians, encoder_output.pred_context_pose

    # Predict cameras for context and target views together.
    num_context_view = ctx_images.shape[1]
    vggt_input_image = torch.cat((ctx_images, tgt_images), dim=1).to(torch.bfloat16)
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
        aggregated_tokens_list, patch_start_idx = model.encoder.aggregator(
            vggt_input_image,
            intermediate_layer_idx=model.encoder.cfg.intermediate_layer_idx,
        )
    with torch.cuda.amp.autocast(enabled=False):
        fp32_tokens = [token.float() for token in aggregated_tokens_list]
        pred_all_pose_enc = model.encoder.camera_head(fp32_tokens)[-1]
        pred_all_extrinsic, pred_all_intrinsic = pose_encoding_to_extri_intri(
            pred_all_pose_enc, vggt_input_image.shape[-2:]
        )

    # Append the [0, 0, 0, 1] row to get 4x4 matrices, then invert them.
    extrinsic_padding = (
        torch.tensor(
            [0, 0, 0, 1],
            device=pred_all_extrinsic.device,
            dtype=pred_all_extrinsic.dtype,
        )
        .view(1, 1, 1, 4)
        .repeat(b, vggt_input_image.shape[1], 1, 1)
    )
    pred_all_extrinsic = torch.cat([pred_all_extrinsic, extrinsic_padding], dim=2).inverse()

    # Divide the first/second intrinsics rows by image width/height.
    pred_all_intrinsic[:, :, 0] = pred_all_intrinsic[:, :, 0] / w
    pred_all_intrinsic[:, :, 1] = pred_all_intrinsic[:, :, 1] / h

    pred_all_context_extrinsic = pred_all_extrinsic[:, :num_context_view]
    pred_all_target_extrinsic = pred_all_extrinsic[:, num_context_view:]
    pred_all_context_intrinsic = pred_all_intrinsic[:, :num_context_view]
    pred_all_target_intrinsic = pred_all_intrinsic[:, num_context_view:]

    # Rescale the jointly predicted translations so the context cameras
    # agree (in mean translation) with the poses the Gaussians came with.
    scale_factor = (
        pred_context_pose["extrinsic"][:, :, :3, 3].mean()
        / pred_all_context_extrinsic[:, :, :3, 3].mean()
    )
    pred_all_target_extrinsic[..., :3, 3] = pred_all_target_extrinsic[..., :3, 3] * scale_factor
    pred_all_context_extrinsic[..., :3, 3] = pred_all_context_extrinsic[..., :3, 3] * scale_factor
    print("scale_factor:", scale_factor)

    # Datasets are built from the in-memory images and predicted cameras
    # (the COLMAP `Parser` path is not used here).
    self.trainset = Dataset(
        split="train",
        images=ctx_images[0].detach().cpu().numpy(),
        camtoworlds=pred_all_context_extrinsic[0].detach().cpu().numpy(),
        Ks=pred_all_context_intrinsic[0].detach().cpu().numpy(),
        patch_size=cfg.patch_size,
        load_depths=cfg.depth_loss,
    )
    self.valset = Dataset(
        images=tgt_images[0].detach().cpu().numpy(),
        camtoworlds=pred_all_target_extrinsic[0].detach().cpu().numpy(),
        Ks=pred_all_target_intrinsic[0].detach().cpu().numpy(),
        split="val",
    )

    # Model
    feature_dim = 32 if cfg.app_opt else None
    self.splats, self.optimizers = create_splats_with_optimizers(
        gaussians=gaussians,
        init_num_pts=cfg.init_num_pts,
        init_extent=cfg.init_extent,
        init_opacity=cfg.init_opa,
        init_scale=cfg.init_scale,
        sh_degree=cfg.sh_degree,
        sparse_grad=cfg.sparse_grad,
        visible_adam=cfg.visible_adam,
        batch_size=cfg.batch_size,
        feature_dim=feature_dim,
        device=self.device,
        world_rank=world_rank,
        world_size=world_size,
        cfg=cfg,
    )
    print("Model initialized. Number of GS:", len(self.splats["means"]))

    # Densification Strategy
    self.cfg.strategy.check_sanity(self.splats, self.optimizers)
    if isinstance(self.cfg.strategy, DefaultStrategy):
        self.strategy_state = self.cfg.strategy.initialize_state(
            scene_scale=1.0
        )
    elif isinstance(self.cfg.strategy, MCMCStrategy):
        self.strategy_state = self.cfg.strategy.initialize_state()
    else:
        assert_never(self.cfg.strategy)

    # Compression Strategy
    self.compression_method = None
    if cfg.compression is not None:
        if cfg.compression == "png":
            self.compression_method = PngCompression()
        else:
            raise ValueError(f"Unknown compression strategy: {cfg.compression}")

    # Optional per-image camera pose refinement.
    self.pose_optimizers = []
    if cfg.pose_opt:
        self.pose_adjust = CameraOptModule(len(self.trainset)).to(self.device)
        self.pose_adjust.zero_init()
        self.pose_optimizers = [
            torch.optim.Adam(
                self.pose_adjust.parameters(),
                lr=cfg.pose_opt_lr * math.sqrt(cfg.batch_size),
                weight_decay=cfg.pose_opt_reg,
            )
        ]
        if world_size > 1:
            self.pose_adjust = DDP(self.pose_adjust)

    # Optional pose perturbation (only for testing pose optimization).
    if cfg.pose_noise > 0.0:
        self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device)
        self.pose_perturb.random_init(cfg.pose_noise)
        if world_size > 1:
            self.pose_perturb = DDP(self.pose_perturb)

    # Optional appearance module.
    self.app_optimizers = []
    if cfg.app_opt:
        assert feature_dim is not None
        self.app_module = AppearanceOptModule(
            len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree
        ).to(self.device)
        # initialize the last layer to be zero so that the initial output is zero.
        torch.nn.init.zeros_(self.app_module.color_head[-1].weight)
        torch.nn.init.zeros_(self.app_module.color_head[-1].bias)
        self.app_optimizers = [
            torch.optim.Adam(
                self.app_module.embeds.parameters(),
                lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0,
                weight_decay=cfg.app_opt_reg,
            ),
            torch.optim.Adam(
                self.app_module.color_head.parameters(),
                lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size),
            ),
        ]
        if world_size > 1:
            self.app_module = DDP(self.app_module)

    # Optional bilateral grid.
    self.bil_grid_optimizers = []
    if cfg.use_bilateral_grid:
        self.bil_grids = BilateralGrid(
            len(self.trainset),
            grid_X=cfg.bilateral_grid_shape[0],
            grid_Y=cfg.bilateral_grid_shape[1],
            grid_W=cfg.bilateral_grid_shape[2],
        ).to(self.device)
        self.bil_grid_optimizers = [
            torch.optim.Adam(
                self.bil_grids.parameters(),
                lr=2e-3 * math.sqrt(cfg.batch_size),
                eps=1e-15,
            ),
        ]

    # Losses & Metrics.
    self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device)
    self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device)

    if cfg.lpips_net == "alex":
        self.lpips = LearnedPerceptualImagePatchSimilarity(
            net_type="alex", normalize=True
        ).to(self.device)
    elif cfg.lpips_net == "vgg":
        # The 3DGS official repo uses lpips vgg, which is equivalent with the following:
        self.lpips = LearnedPerceptualImagePatchSimilarity(
            net_type="vgg", normalize=False
        ).to(self.device)
    else:
        raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}")

    # Viewer
    if not self.cfg.disable_viewer:
        self.server = viser.ViserServer(port=cfg.port, verbose=False)
        self.viewer = GsplatViewer(
            server=self.server,
            render_fn=self._viewer_render_fn,
            output_dir=Path(cfg.result_dir),
            mode="training",
        )
def rasterize_splats(
    self,
    camtoworlds: Tensor,
    Ks: Tensor,
    width: int,
    height: int,
    masks: Optional[Tensor] = None,
    rasterize_mode: Optional[Literal["classic", "antialiased"]] = None,
    camera_model: Optional[Literal["pinhole", "ortho", "fisheye"]] = None,
    **kwargs,
) -> Tuple[Tensor, Tensor, Dict]:
    """Render the current splats from the given cameras.

    Args:
        camtoworlds: Camera-to-world matrices, [C, 4, 4]; they are
            inverted to obtain the view matrices.
        Ks: Camera intrinsics, [C, 3, 3].
        width: Output image width in pixels.
        height: Output image height in pixels.
        masks: Optional boolean mask; rendered pixels where it is false
            are set to 0.
        rasterize_mode: Overrides the mode derived from
            ``cfg.antialiased`` when given.
        camera_model: Overrides ``cfg.camera_model`` when given.
        **kwargs: Forwarded to gsplat's ``rasterization``. ``image_ids``
            is consumed here and, when appearance optimisation is on, so
            is ``sh_degree``.

    Returns:
        ``(render_colors, render_alphas, info)`` as returned by
        ``rasterization``.
    """
    means = self.splats["means"]  # [N, 3]
    # Quaternions are passed unnormalised: rasterization normalises them.
    quats = self.splats["quats"]  # [N, 4]
    scales = torch.exp(self.splats["scales"])  # [N, 3]
    opacities = torch.sigmoid(self.splats["opacities"])  # [N,]

    image_ids = kwargs.pop("image_ids", None)
    if self.cfg.app_opt:
        # NOTE(review): this branch reads self.splats["features"] and
        # self.splats["colors"], which create_splats_with_optimizers in
        # this file does not create -- confirm before enabling app_opt.
        colors = self.app_module(
            features=self.splats["features"],
            embed_ids=image_ids,
            dirs=means[None, :, :] - camtoworlds[:, None, :3, 3],
            sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree),
        )
        colors = colors + self.splats["colors"]
        colors = torch.sigmoid(colors)
    else:
        colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1)  # [N, K, 3]

    if rasterize_mode is None:
        rasterize_mode = "antialiased" if self.cfg.antialiased else "classic"
    if camera_model is None:
        camera_model = self.cfg.camera_model

    # Black background, one row per camera, on the same device as the
    # splats (previously a fixed (1, 3) tensor on the default CUDA device).
    backgrounds = torch.zeros(*camtoworlds.shape[:-2], 3, device=means.device)

    render_colors, render_alphas, info = rasterization(
        means=means,
        quats=quats,
        scales=scales,
        opacities=opacities,
        colors=colors,
        viewmats=torch.linalg.inv(camtoworlds),  # [C, 4, 4]
        Ks=Ks,  # [C, 3, 3]
        width=width,
        height=height,
        packed=self.cfg.packed,
        absgrad=(
            self.cfg.strategy.absgrad
            if isinstance(self.cfg.strategy, DefaultStrategy)
            else False
        ),
        sparse_grad=self.cfg.sparse_grad,
        rasterize_mode=rasterize_mode,
        distributed=self.world_size > 1,
        # Pass the resolved value so a caller-supplied camera_model is
        # honoured instead of always using the config default.
        camera_model=camera_model,
        radius_clip=0.1,
        backgrounds=backgrounds,
        **kwargs,
    )
    if masks is not None:
        render_colors[~masks] = 0
    return render_colors, render_alphas, info
| def train(self): | |
| cfg = self.cfg | |
| device = self.device | |
| world_rank = self.world_rank | |
| world_size = self.world_size | |
| # Dump cfg. | |
| if world_rank == 0: | |
| with open(f"{cfg.result_dir}/cfg.yml", "w") as f: | |
| yaml.dump(vars(cfg), f) | |
| max_steps = cfg.max_steps | |
| init_step = 0 | |
| schedulers = [ | |
| # means has a learning rate schedule, that end at 0.01 of the initial value | |
| torch.optim.lr_scheduler.ExponentialLR( | |
| self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps) | |
| ), | |
| ] | |
| if cfg.pose_opt: | |
| # pose optimization has a learning rate schedule | |
| schedulers.append( | |
| torch.optim.lr_scheduler.ExponentialLR( | |
| self.pose_optimizers[0], gamma=0.01 ** (1.0 / max_steps) | |
| ) | |
| ) | |
| if cfg.use_bilateral_grid: | |
| # bilateral grid has a learning rate schedule. Linear warmup for 1000 steps. | |
| schedulers.append( | |
| torch.optim.lr_scheduler.ChainedScheduler( | |
| [ | |
| torch.optim.lr_scheduler.LinearLR( | |
| self.bil_grid_optimizers[0], | |
| start_factor=0.01, | |
| total_iters=1000, | |
| ), | |
| torch.optim.lr_scheduler.ExponentialLR( | |
| self.bil_grid_optimizers[0], gamma=0.01 ** (1.0 / max_steps) | |
| ), | |
| ] | |
| ) | |
| ) | |
| trainloader = torch.utils.data.DataLoader( | |
| self.trainset, | |
| batch_size=cfg.batch_size, | |
| shuffle=True, | |
| num_workers=4, | |
| persistent_workers=True, | |
| pin_memory=True, | |
| ) | |
| trainloader_iter = iter(trainloader) | |
| # Training loop. | |
| global_tic = time.time() | |
| pbar = tqdm.tqdm(range(init_step, max_steps)) | |
| for step in pbar: | |
| if not cfg.disable_viewer: | |
| while self.viewer.state == "paused": | |
| time.sleep(0.01) | |
| self.viewer.lock.acquire() | |
| tic = time.time() | |
| try: | |
| data = next(trainloader_iter) | |
| except StopIteration: | |
| trainloader_iter = iter(trainloader) | |
| data = next(trainloader_iter) | |
| camtoworlds = camtoworlds_gt = data["camtoworld"].to(device) # [1, 4, 4] | |
| Ks = data["K"].to(device) # [1, 3, 3] | |
| pixels = data["image"].to(device) / 255.0 # [1, H, W, 3] | |
| num_train_rays_per_step = ( | |
| pixels.shape[0] * pixels.shape[1] * pixels.shape[2] | |
| ) | |
| image_ids = data["image_id"].to(device) | |
| masks = data["mask"].to(device) if "mask" in data else None # [1, H, W] | |
| if cfg.depth_loss: | |
| points = data["points"].to(device) # [1, M, 2] | |
| depths_gt = data["depths"].to(device) # [1, M] | |
| height, width = pixels.shape[1:3] | |
| if cfg.pose_noise: | |
| camtoworlds = self.pose_perturb(camtoworlds, image_ids) | |
| if cfg.pose_opt: | |
| camtoworlds = self.pose_adjust(camtoworlds, image_ids) | |
| # sh schedule | |
| # sh_degree_to_use = min(step // cfg.sh_degree_interval, cfg.sh_degree) | |
| sh_degree_to_use = cfg.sh_degree | |
| # forward | |
| renders, alphas, info = self.rasterize_splats( | |
| camtoworlds=camtoworlds, | |
| Ks=Ks, | |
| width=width, | |
| height=height, | |
| sh_degree=sh_degree_to_use, | |
| near_plane=cfg.near_plane, | |
| far_plane=cfg.far_plane, | |
| image_ids=image_ids, | |
| render_mode="RGB+ED" if cfg.depth_loss else "RGB", | |
| masks=masks, | |
| ) | |
| if renders.shape[-1] == 4: | |
| colors, depths = renders[..., 0:3], renders[..., 3:4] | |
| else: | |
| colors, depths = renders, None | |
| if cfg.use_bilateral_grid: | |
| grid_y, grid_x = torch.meshgrid( | |
| (torch.arange(height, device=self.device) + 0.5) / height, | |
| (torch.arange(width, device=self.device) + 0.5) / width, | |
| indexing="ij", | |
| ) | |
| grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) | |
| colors = slice(self.bil_grids, grid_xy, colors, image_ids)["rgb"] | |
| if cfg.random_bkgd: | |
| bkgd = torch.rand(1, 3, device=device) | |
| colors = colors + bkgd * (1.0 - alphas) | |
| self.cfg.strategy.step_pre_backward( | |
| params=self.splats, | |
| optimizers=self.optimizers, | |
| state=self.strategy_state, | |
| step=step, | |
| info=info, | |
| ) | |
| # loss | |
| l1loss = F.l1_loss(colors, pixels) | |
| ssimloss = 1.0 - fused_ssim( | |
| colors.permute(0, 3, 1, 2), pixels.permute(0, 3, 1, 2), padding="valid" | |
| ) | |
| loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda | |
| if cfg.depth_loss: | |
| # query depths from depth map | |
| points = torch.stack( | |
| [ | |
| points[:, :, 0] / (width - 1) * 2 - 1, | |
| points[:, :, 1] / (height - 1) * 2 - 1, | |
| ], | |
| dim=-1, | |
| ) # normalize to [-1, 1] | |
| grid = points.unsqueeze(2) # [1, M, 1, 2] | |
| depths = F.grid_sample( | |
| depths.permute(0, 3, 1, 2), grid, align_corners=True | |
| ) # [1, 1, M, 1] | |
| depths = depths.squeeze(3).squeeze(1) # [1, M] | |
| # calculate loss in disparity space | |
| disp = torch.where(depths > 0.0, 1.0 / depths, torch.zeros_like(depths)) | |
| disp_gt = 1.0 / depths_gt # [1, M] | |
| depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale | |
| loss += depthloss * cfg.depth_lambda | |
| if cfg.use_bilateral_grid: | |
| tvloss = 10 * total_variation_loss(self.bil_grids.grids) | |
| loss += tvloss | |
| # regularizations | |
| if cfg.opacity_reg > 0.0: | |
| loss = ( | |
| loss | |
| + cfg.opacity_reg | |
| * torch.abs(torch.sigmoid(self.splats["opacities"])).mean() | |
| ) | |
| if cfg.scale_reg > 0.0: | |
| loss = ( | |
| loss | |
| + cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean() | |
| ) | |
| loss.backward() | |
| desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " | |
| if cfg.depth_loss: | |
| desc += f"depth loss={depthloss.item():.6f}| " | |
| if cfg.pose_opt and cfg.pose_noise: | |
| # monitor the pose error if we inject noise | |
| pose_err = F.l1_loss(camtoworlds_gt, camtoworlds) | |
| desc += f"pose err={pose_err.item():.6f}| " | |
| pbar.set_description(desc) | |
| # write images (gt and render) | |
| # if world_rank == 0 and step % 800 == 0: | |
| # canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() | |
| # canvas = canvas.reshape(-1, *canvas.shape[2:]) | |
| # imageio.imwrite( | |
| # f"{self.render_dir}/train_rank{self.world_rank}.png", | |
| # (canvas * 255).astype(np.uint8), | |
| # ) | |
| if world_rank == 0 and cfg.tb_every > 0 and step % cfg.tb_every == 0: | |
| mem = torch.cuda.max_memory_allocated() / 1024**3 | |
| self.writer.add_scalar("train/loss", loss.item(), step) | |
| self.writer.add_scalar("train/l1loss", l1loss.item(), step) | |
| self.writer.add_scalar("train/ssimloss", ssimloss.item(), step) | |
| self.writer.add_scalar("train/num_GS", len(self.splats["means"]), step) | |
| self.writer.add_scalar("train/mem", mem, step) | |
| if cfg.depth_loss: | |
| self.writer.add_scalar("train/depthloss", depthloss.item(), step) | |
| if cfg.use_bilateral_grid: | |
| self.writer.add_scalar("train/tvloss", tvloss.item(), step) | |
| if cfg.tb_save_image: | |
| canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() | |
| canvas = canvas.reshape(-1, *canvas.shape[2:]) | |
| self.writer.add_image("train/render", canvas, step) | |
| self.writer.flush() | |
| # save checkpoint before updating the model | |
| if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1: | |
| mem = torch.cuda.max_memory_allocated() / 1024**3 | |
| stats = { | |
| "mem": mem, | |
| "ellipse_time": time.time() - global_tic, | |
| "num_GS": len(self.splats["means"]), | |
| } | |
| print("Step: ", step, stats) | |
| with open( | |
| f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json", | |
| "w", | |
| ) as f: | |
| json.dump(stats, f) | |
| data = {"step": step, "splats": self.splats.state_dict()} | |
| if cfg.pose_opt: | |
| if world_size > 1: | |
| data["pose_adjust"] = self.pose_adjust.module.state_dict() | |
| else: | |
| data["pose_adjust"] = self.pose_adjust.state_dict() | |
| if cfg.app_opt: | |
| if world_size > 1: | |
| data["app_module"] = self.app_module.module.state_dict() | |
| else: | |
| data["app_module"] = self.app_module.state_dict() | |
| torch.save( | |
| data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt" | |
| ) | |
| if ( | |
| step in [i - 1 for i in cfg.ply_steps] or step == max_steps - 1 | |
| ) and cfg.save_ply: | |
| if self.cfg.app_opt: | |
| # eval at origin to bake the appeareance into the colors | |
| rgb = self.app_module( | |
| features=self.splats["features"], | |
| embed_ids=None, | |
| dirs=torch.zeros_like(self.splats["means"][None, :, :]), | |
| sh_degree=sh_degree_to_use, | |
| ) | |
| rgb = rgb + self.splats["colors"] | |
| rgb = torch.sigmoid(rgb).squeeze(0).unsqueeze(1) | |
| sh0 = rgb_to_sh(rgb) | |
| shN = torch.empty([sh0.shape[0], 0, 3], device=sh0.device) | |
| else: | |
| sh0 = self.splats["sh0"] | |
| shN = self.splats["shN"] | |
| # shN = torch.empty([sh0.shape[0], 0, 3], device=sh0.device) | |
| means = self.splats["means"] | |
| scales = self.splats["scales"] | |
| quats = self.splats["quats"] | |
| opacities = self.splats["opacities"] | |
| # export_splats( | |
| # means=means, | |
| # scales=scales, | |
| # quats=quats, | |
| # opacities=opacities, | |
| # sh0=sh0, | |
| # shN=shN, | |
| # format="ply", | |
| # save_to=f"{self.ply_dir}/point_cloud_{step}.ply", | |
| # ) | |
| export_ply( | |
| means=means, | |
| scales=scales, | |
| rotations=quats, | |
| harmonics=torch.cat([sh0, shN], dim=1).permute(0, 2, 1), | |
| opacities=opacities.sigmoid(), | |
| path=Path(f"{self.ply_dir}/point_cloud_{step}.ply"), | |
| ) | |
| # Turn Gradients into Sparse Tensor before running optimizer | |
| if cfg.sparse_grad: | |
| assert cfg.packed, "Sparse gradients only work with packed mode." | |
| gaussian_ids = info["gaussian_ids"] | |
| for k in self.splats.keys(): | |
| grad = self.splats[k].grad | |
| if grad is None or grad.is_sparse: | |
| continue | |
| self.splats[k].grad = torch.sparse_coo_tensor( | |
| indices=gaussian_ids[None], # [1, nnz] | |
| values=grad[gaussian_ids], # [nnz, ...] | |
| size=self.splats[k].size(), # [N, ...] | |
| is_coalesced=len(Ks) == 1, | |
| ) | |
| if cfg.visible_adam: | |
| gaussian_cnt = self.splats.means.shape[0] | |
| if cfg.packed: | |
| visibility_mask = torch.zeros_like( | |
| self.splats["opacities"], dtype=bool | |
| ) | |
| visibility_mask.scatter_(0, info["gaussian_ids"], 1) | |
| else: | |
| visibility_mask = (info["radii"] > 0).all(-1).any(0) | |
| # optimize | |
| for optimizer in self.optimizers.values(): | |
| if cfg.visible_adam: | |
| optimizer.step(visibility_mask) | |
| else: | |
| optimizer.step() | |
| optimizer.zero_grad(set_to_none=True) | |
| for optimizer in self.pose_optimizers: | |
| optimizer.step() | |
| optimizer.zero_grad(set_to_none=True) | |
| for optimizer in self.app_optimizers: | |
| optimizer.step() | |
| optimizer.zero_grad(set_to_none=True) | |
| for optimizer in self.bil_grid_optimizers: | |
| optimizer.step() | |
| optimizer.zero_grad(set_to_none=True) | |
| for scheduler in schedulers: | |
| scheduler.step() | |
| # Run post-backward steps after backward and optimizer | |
| if isinstance(self.cfg.strategy, DefaultStrategy): | |
| self.cfg.strategy.step_post_backward( | |
| params=self.splats, | |
| optimizers=self.optimizers, | |
| state=self.strategy_state, | |
| step=step, | |
| info=info, | |
| packed=cfg.packed, | |
| ) | |
| elif isinstance(self.cfg.strategy, MCMCStrategy): | |
| self.cfg.strategy.step_post_backward( | |
| params=self.splats, | |
| optimizers=self.optimizers, | |
| state=self.strategy_state, | |
| step=step, | |
| info=info, | |
| lr=schedulers[0].get_last_lr()[0], | |
| ) | |
| else: | |
| assert_never(self.cfg.strategy) | |
| # eval the full set | |
| if step in [i - 1 for i in cfg.eval_steps]: | |
| self.eval(step) | |
| # self.render_traj(step) | |
| # run compression | |
| if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]: | |
| self.run_compression(step=step) | |
| if not cfg.disable_viewer: | |
| self.viewer.lock.release() | |
| num_train_steps_per_sec = 1.0 / (time.time() - tic) | |
| num_train_rays_per_sec = ( | |
| num_train_rays_per_step * num_train_steps_per_sec | |
| ) | |
| # Update the viewer state. | |
| self.viewer.render_tab_state.num_train_rays_per_sec = ( | |
| num_train_rays_per_sec | |
| ) | |
| # Update the scene. | |
| self.viewer.update(step, num_train_rays_per_step) | |
def eval(self, step: int, stage: str = "val"):
    """Render every validation view, score it and log the averaged metrics.

    For each item of ``self.valset`` (batch size 1) the scene is rasterized
    in "RGB+ED" mode and the colour channels are clamped to [0, 1]. On rank
    0 the ground truth and the render are written to ``self.render_dir``
    (side by side and as separate PNGs), and PSNR / SSIM / LPIPS are
    accumulated; a colour-corrected PSNR is added when
    ``cfg.use_bilateral_grid`` is set. The per-image means, the average
    render time (stored under the existing ``"ellipse_time"`` key) and the
    Gaussian count are then dumped to ``self.stats_dir`` as JSON and sent
    to tensorboard.

    Args:
        step: Training step used to tag file names and tensorboard scalars.
        stage: Prefix for output files and scalar names (e.g. "val",
            "compress").
    """
    print("Running evaluation...")
    cfg = self.cfg
    device = self.device
    world_rank = self.world_rank

    valloader = torch.utils.data.DataLoader(
        self.valset, batch_size=1, shuffle=False, num_workers=1
    )
    # Without this guard an empty validation set ends in a ZeroDivisionError
    # when the render time is averaged below.
    if len(valloader) == 0:
        print("Validation set is empty; skipping evaluation.")
        return

    ellipse_time = 0.0
    metrics = defaultdict(list)
    for i, data in enumerate(valloader):
        camtoworlds = data["camtoworld"].to(device)
        Ks = data["K"].to(device)
        pixels = data["image"].to(device) / 255.0
        masks = data["mask"].to(device) if "mask" in data else None
        height, width = pixels.shape[1:3]

        # Synchronize around the render so the wall-clock time covers the
        # GPU work and not just the kernel launch.
        torch.cuda.synchronize()
        tic = time.time()
        render_colors, _, _ = self.rasterize_splats(
            camtoworlds=camtoworlds,
            Ks=Ks,
            width=width,
            height=height,
            sh_degree=cfg.sh_degree,
            near_plane=cfg.near_plane,
            far_plane=cfg.far_plane,
            # radius_clip=0.1,
            render_mode="RGB+ED",
            masks=masks,
        )  # colour in channels 0..2, expected depth in channel 3
        torch.cuda.synchronize()
        ellipse_time += time.time() - tic

        colors = torch.clamp(render_colors[..., :3], 0.0, 1.0)

        if world_rank == 0:
            pixels_p = pixels.permute(0, 3, 1, 2)  # [1, 3, H, W]
            colors_p = colors.permute(0, 3, 1, 2)  # [1, 3, H, W]

            # write images: ground truth and render concatenated along dim 2
            canvas = torch.cat([pixels, colors], dim=2).squeeze(0).cpu().numpy()
            canvas = (canvas * 255).astype(np.uint8)
            imageio.imwrite(
                f"{self.render_dir}/{stage}_step{step}_{i:04d}.png",
                canvas,
            )
            torchvision.utils.save_image(
                pixels_p, f"{self.render_dir}/gt_rgb_{stage}_step{step}_{i:04d}.png"
            )
            torchvision.utils.save_image(
                colors_p,
                f"{self.render_dir}/render_rgb_{stage}_step{step}_{i:04d}.png",
            )

            metrics["psnr"].append(self.psnr(colors_p, pixels_p))
            metrics["ssim"].append(self.ssim(colors_p, pixels_p))
            metrics["lpips"].append(self.lpips(colors_p, pixels_p))
            if cfg.use_bilateral_grid:
                cc_colors = color_correct(colors, pixels)
                cc_colors_p = cc_colors.permute(0, 3, 1, 2)  # [1, 3, H, W]
                metrics["cc_psnr"].append(self.psnr(cc_colors_p, pixels_p))

    if world_rank == 0:
        ellipse_time /= len(valloader)

        stats = {k: torch.stack(v).mean().item() for k, v in metrics.items()}
        stats.update(
            {
                "ellipse_time": ellipse_time,
                "num_GS": len(self.splats["means"]),
            }
        )
        print(
            f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} "
            f"Time: {stats['ellipse_time']:.3f}s/image "
            f"Number of GS: {stats['num_GS']}"
        )
        # save stats as json
        with open(f"{self.stats_dir}/{stage}_step{step:04d}.json", "w") as f:
            json.dump(stats, f)
        # save stats to tensorboard
        for k, v in stats.items():
            self.writer.add_scalar(f"{stage}/{k}", v, step)
        self.writer.flush()
def render_traj(self, step: int):
    """Render a camera trajectory and save it as an mp4 video.

    Does nothing when ``cfg.disable_video`` is set. Otherwise a path is
    built from the parser's camera poses according to
    ``cfg.render_traj_path`` ("interp", "ellipse" or "spiral"), every pose
    is rasterized in "RGB+ED" mode, and colour and min/max-normalised depth
    are written side by side into ``{cfg.result_dir}/videos/traj_{step}.mp4``.

    Args:
        step: Training step used in the output file name.

    Raises:
        ValueError: If ``cfg.render_traj_path`` is not a supported type.
    """
    if self.cfg.disable_video:
        return
    print("Running trajectory rendering...")
    cfg = self.cfg
    device = self.device

    # The first and last five poses are left out of the trajectory.
    camtoworlds_all = self.parser.camtoworlds[5:-5]
    if cfg.render_traj_path == "interp":
        camtoworlds_all = generate_interpolated_path(
            camtoworlds_all, 1
        )  # [N, 3, 4]
    elif cfg.render_traj_path == "ellipse":
        height = camtoworlds_all[:, 2, 3].mean()
        camtoworlds_all = generate_ellipse_path_z(
            camtoworlds_all, height=height
        )  # [N, 3, 4]
    elif cfg.render_traj_path == "spiral":
        camtoworlds_all = generate_spiral_path(
            camtoworlds_all,
            bounds=self.parser.bounds * self.scene_scale,
            spiral_scale_r=self.parser.extconf["spiral_radius_scale"],
        )
    else:
        raise ValueError(
            f"Render trajectory type not supported: {cfg.render_traj_path}"
        )

    # Append the homogeneous bottom row so each pose becomes 4x4.
    camtoworlds_all = np.concatenate(
        [
            camtoworlds_all,
            np.repeat(
                np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds_all), axis=0
            ),
        ],
        axis=1,
    )  # [N, 4, 4]

    camtoworlds_all = torch.from_numpy(camtoworlds_all).float().to(device)
    # Intrinsics and image size are taken from the first camera entry.
    K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device)
    width, height = list(self.parser.imsize_dict.values())[0]

    # save to video
    video_dir = f"{cfg.result_dir}/videos"
    os.makedirs(video_dir, exist_ok=True)
    writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30)
    try:
        for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"):
            camtoworlds = camtoworlds_all[i : i + 1]
            Ks = K[None]

            renders, _, _ = self.rasterize_splats(
                camtoworlds=camtoworlds,
                Ks=Ks,
                width=width,
                height=height,
                sh_degree=cfg.sh_degree,
                near_plane=cfg.near_plane,
                far_plane=cfg.far_plane,
                render_mode="RGB+ED",
            )  # [1, H, W, 4]
            colors = torch.clamp(renders[..., 0:3], 0.0, 1.0)  # [1, H, W, 3]
            depths = renders[..., 3:4]  # [1, H, W, 1]
            # Floor the range so a constant depth map yields zeros, not NaNs.
            depth_range = (depths.max() - depths.min()).clamp_min(1e-10)
            depths = (depths - depths.min()) / depth_range
            canvas_list = [colors, depths.repeat(1, 1, 1, 3)]

            # write images
            canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy()
            canvas = (canvas * 255).astype(np.uint8)
            writer.append_data(canvas)
    finally:
        # Always release the video file handle, even if a frame fails.
        writer.close()
    print(f"Video saved to {video_dir}/traj_{step}.mp4")
def run_compression(self, step: int):
    """Compress the splats, load them back and evaluate the result.

    The splats are written to ``{cfg.result_dir}/compression/rank{world_rank}``
    with ``self.compression_method``, decompressed again, copied into
    ``self.splats`` (overwriting the current parameter data in place) and
    finally scored through ``self.eval`` with ``stage="compress"``.

    Args:
        step: Training step forwarded to the evaluation.
    """
    print("Running compression...")
    # Use the runner's own config; relying on a module-level ``cfg`` only
    # works when this file is executed directly as a script.
    cfg = self.cfg
    world_rank = self.world_rank

    compress_dir = f"{cfg.result_dir}/compression/rank{world_rank}"
    os.makedirs(compress_dir, exist_ok=True)

    self.compression_method.compress(compress_dir, self.splats)

    # evaluate compression
    splats_c = self.compression_method.decompress(compress_dir)
    for name, tensor in splats_c.items():
        self.splats[name].data = tensor.to(self.device)
    self.eval(step=step, stage="compress")
def _viewer_render_fn(
    self, camera_state: CameraState, render_tab_state: RenderTabState
):
    """Render one frame for the interactive viewer.

    The output resolution comes from the render tab (preview size when
    ``preview_render`` is on, viewer size otherwise). The scene is
    rasterized with the tab's settings, the Gaussian counters on the tab
    are refreshed, and the result is returned as a numpy array: clamped RGB
    for "rgb", or a colormapped image for the depth and alpha modes.
    """
    assert isinstance(render_tab_state, GsplatRenderTabState)
    tab = render_tab_state

    if tab.preview_render:
        width, height = tab.render_width, tab.render_height
    else:
        width, height = tab.viewer_width, tab.viewer_height

    pose_np = camera_state.c2w
    intrinsics_np = camera_state.get_K((width, height))
    pose = torch.from_numpy(pose_np).float().to(self.device)
    intrinsics = torch.from_numpy(intrinsics_np).float().to(self.device)

    # Viewer mode name -> rasterizer render mode. "alpha" renders RGB because
    # only the returned alpha map is used for that view.
    rasterizer_modes = {
        "rgb": "RGB",
        "depth(accumulated)": "D",
        "depth(expected)": "ED",
        "alpha": "RGB",
    }

    render_colors, render_alphas, info = self.rasterize_splats(
        camtoworlds=pose[None],
        Ks=intrinsics[None],
        width=width,
        height=height,
        sh_degree=min(tab.max_sh_degree, self.cfg.sh_degree),
        near_plane=tab.near_plane,
        far_plane=tab.far_plane,
        radius_clip=tab.radius_clip,
        eps2d=tab.eps2d,
        backgrounds=torch.tensor([tab.backgrounds], device=self.device) / 255.0,
        render_mode=rasterizer_modes[tab.render_mode],
        rasterize_mode=tab.rasterize_mode,
        camera_model=tab.camera_model,
    )

    # Report how many Gaussians exist and how many have a positive radius.
    tab.total_gs_count = len(self.splats["means"])
    tab.rendered_gs_count = (info["radii"] > 0).all(-1).sum().item()

    if tab.render_mode == "rgb":
        # SH-based colours may leave [0, 1], so clamp before display.
        rgb = render_colors[0, ..., 0:3].clamp(0, 1)
        renders = rgb.cpu().numpy()
    elif tab.render_mode in ["depth(accumulated)", "depth(expected)"]:
        depth = render_colors[0, ..., 0:1]
        # Normalise against the configured planes or the observed range.
        if tab.normalize_nearfar:
            lo, hi = tab.near_plane, tab.far_plane
        else:
            lo, hi = depth.min(), depth.max()
        depth_norm = torch.clip((depth - lo) / (hi - lo + 1e-10), 0, 1)
        if tab.inverse:
            depth_norm = 1 - depth_norm
        colored = apply_float_colormap(depth_norm, tab.colormap)
        renders = colored.cpu().numpy()
    elif tab.render_mode == "alpha":
        alpha = render_alphas[0, ..., 0:1]
        if tab.inverse:
            alpha = 1 - alpha
        colored = apply_float_colormap(alpha, tab.colormap)
        renders = colored.cpu().numpy()
    return renders
def main(local_rank: int, world_rank, world_size: int, cfg: Config):
    """Per-process entry point: build a Runner, then evaluate or train.

    When ``cfg.ckpt`` is given, the listed checkpoints are loaded, their
    splats concatenated into the runner, and only evaluation (plus optional
    compression) is run. Otherwise the runner trains and is evaluated once
    at ``max_steps``. The viewer is switched off for multi-process runs.
    """
    multi_process = world_size > 1
    if multi_process and not cfg.disable_viewer:
        cfg.disable_viewer = True
        if world_rank == 0:
            print("Viewer is disabled in distributed training.")

    runner = Runner(local_rank, world_rank, world_size, cfg)

    if cfg.ckpt is None:
        runner.train()
        runner.eval(step=runner.cfg.max_steps)
    else:
        # Evaluation-only path.
        ckpts = []
        for file in cfg.ckpt:
            ckpts.append(
                torch.load(file, map_location=runner.device, weights_only=True)
            )
        for k in runner.splats.keys():
            pieces = [ckpt["splats"][k] for ckpt in ckpts]
            runner.splats[k].data = torch.cat(pieces)
        # The step recorded in the first checkpoint tags the outputs.
        step = ckpts[0]["step"]
        runner.eval(step=step)
        if cfg.compression is not None:
            runner.run_compression(step=step)

    print("Training complete.")
if __name__ == "__main__":
    """
    Usage:

    ```bash
    # Single GPU training
    CUDA_VISIBLE_DEVICES=9 python -m examples.simple_trainer default

    # Distributed training on 4 GPUs: Effectively 4x batch size so run 4x less steps.
    CUDA_VISIBLE_DEVICES=0,1,2,3 python simple_trainer.py default --steps_scaler 0.25
    ```
    """

    # Config objects we can choose between.
    # Each is a tuple of (CLI description, config object).
    configs = {
        "default": (
            "Gaussian splatting training using densification heuristics from the original paper.",
            Config(
                strategy=DefaultStrategy(verbose=True),
            ),
        ),
        "mcmc": (
            "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.",
            Config(
                init_opa=0.5,
                init_scale=0.1,
                opacity_reg=0.01,
                scale_reg=0.01,
                strategy=MCMCStrategy(verbose=True),
            ),
        ),
    }
    cfg = tyro.extras.overridable_config_cli(configs)
    cfg.adjust_steps(cfg.steps_scaler)

    # Check the optional dependencies up front so a missing package is
    # reported before any training work starts.
    if cfg.compression == "png":
        try:
            import plas
            import torchpq
        except ImportError as err:
            # Only a failed import is translated into the install hint; any
            # other error raised while importing keeps its own traceback.
            raise ImportError(
                "To use PNG compression, you need to install "
                "torchpq (instruction at https://github.com/DeMoriarty/TorchPQ?tab=readme-ov-file#install) "
                "and plas (via 'pip install git+https://github.com/fraunhoferhhi/PLAS.git') "
            ) from err

    cli(main, cfg, verbose=True)