from dataclasses import dataclass
from pathlib import Path
import gc
import random
from typing import Literal, Optional, Protocol, runtime_checkable, Any

import moviepy.editor as mpy
import torch
import torchvision
import wandb
from einops import pack, rearrange, repeat
from jaxtyping import Float
from lightning.pytorch import LightningModule
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.utilities import rank_zero_only
from tabulate import tabulate
from torch import Tensor, nn, optim
import torch.nn.functional as F

from loss.loss_lpips import LossLpips
from loss.loss_mse import LossMse
from model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri
from src.utils.render import generate_path
from src.utils.point import get_normal_map

from ..loss.loss_distill import DistillLoss
from ..loss.loss_huber import HuberLoss, extri_intri_to_pose_encoding
# from model.types import Gaussians
from ..dataset.data_module import get_data_shim
from ..dataset.types import BatchedExample
from ..evaluation.metrics import (
    compute_lpips,
    compute_psnr,
    compute_ssim,
    abs_relative_difference,
    delta1_acc,
)
from ..global_cfg import get_cfg
from ..loss import Loss
from ..loss.loss_point import Regr3D
from ..loss.loss_ssim import ssim
from ..misc.benchmarker import Benchmarker
from ..misc.cam_utils import update_pose, get_pnp_pose, rotation_6d_to_matrix
from ..misc.image_io import prep_image, save_image, save_video
from ..misc.LocalLogger import LOG_PATH, LocalLogger
from ..misc.nn_module_tools import convert_to_buffer
from ..misc.step_tracker import StepTracker
from ..misc.utils import inverse_normalize, vis_depth_map, confidence_map, get_overlap_tag
from ..visualization.annotation import add_label
from ..visualization.camera_trajectory.interpolation import (
    interpolate_extrinsics,
    interpolate_intrinsics,
)
from ..visualization.camera_trajectory.wobble import (
    generate_wobble,
    generate_wobble_transformation,
)
from ..visualization.color_map import apply_color_map_to_image
from ..visualization.layout import add_border, hcat, vcat
# from ..visualization.validation_in_3d import render_cameras, render_projections
from .decoder.decoder import Decoder, DepthRenderingMode
from .encoder import Encoder
from .encoder.visualization.encoder_visualizer import EncoderVisualizer
from .ply_export import export_ply

@dataclass
class OptimizerCfg:
    lr: float
    warm_up_steps: int
    backbone_lr_multiplier: float


@dataclass
class TestCfg:
    output_path: Path
    align_pose: bool
    pose_align_steps: int
    rot_opt_lr: float
    trans_opt_lr: float
    compute_scores: bool
    save_image: bool
    save_video: bool
    save_compare: bool
    generate_video: bool
    mode: Literal["inference", "evaluation"]
    image_folder: str


@dataclass
class TrainCfg:
    output_path: Path
    depth_mode: DepthRenderingMode | None
    extended_visualization: bool
    print_log_every_n_steps: int
    distiller: str
    distill_max_steps: int
    pose_loss_alpha: float = 1.0
    pose_loss_delta: float = 1.0
    cxt_depth_weight: float = 0.01
    weight_pose: float = 1.0
    weight_depth: float = 1.0
    weight_normal: float = 1.0
    render_ba: bool = False
    render_ba_after_step: int = 0


@runtime_checkable
class TrajectoryFn(Protocol):
    def __call__(
        self,
        t: Float[Tensor, " t"],
    ) -> tuple[
        Float[Tensor, "batch view 4 4"],  # extrinsics
        Float[Tensor, "batch view 3 3"],  # intrinsics
    ]:
        pass

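# Illustrative only (not part of the original module): a minimal trajectory
# that satisfies TrajectoryFn by holding a single identity camera at every
# time step. `_example_static_trajectory` is a hypothetical name.
def _example_static_trajectory(
    t: Float[Tensor, " t"],
) -> tuple[Float[Tensor, "batch view 4 4"], Float[Tensor, "batch view 3 3"]]:
    extrinsics = repeat(torch.eye(4, device=t.device), "i j -> 1 v i j", v=t.shape[0])
    intrinsics = repeat(torch.eye(3, device=t.device), "i j -> 1 v i j", v=t.shape[0])
    return extrinsics, intrinsics
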
class ModelWrapper(LightningModule):
    logger: Optional[WandbLogger]
    model: nn.Module
    losses: nn.ModuleList
    optimizer_cfg: OptimizerCfg
    test_cfg: TestCfg
    train_cfg: TrainCfg
    step_tracker: StepTracker | None

    def __init__(
        self,
        optimizer_cfg: OptimizerCfg,
        test_cfg: TestCfg,
        train_cfg: TrainCfg,
        model: nn.Module,
        losses: list[Loss],
        step_tracker: StepTracker | None,
    ) -> None:
        super().__init__()
        self.optimizer_cfg = optimizer_cfg
        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        self.step_tracker = step_tracker

        # Set up the model.
        self.encoder_visualizer = None
        self.model = model
        self.data_shim = get_data_shim(self.model.encoder)
        self.losses = nn.ModuleList(losses)
        if self.model.encoder.pred_pose:
            self.loss_pose = HuberLoss(
                alpha=self.train_cfg.pose_loss_alpha,
                delta=self.train_cfg.pose_loss_delta,
            )
        if self.model.encoder.distill:
            self.loss_distill = DistillLoss(
                delta=self.train_cfg.pose_loss_delta,
                weight_pose=self.train_cfg.weight_pose,
                weight_depth=self.train_cfg.weight_depth,
                weight_normal=self.train_cfg.weight_normal,
            )

        # This is used for testing.
        self.benchmarker = Benchmarker()
    def on_train_epoch_start(self) -> None:
        # Our custom dataset and sampler need the epoch set via set_epoch so
        # that per-epoch shuffling stays in sync across distributed workers.
        if hasattr(self.trainer.datamodule.train_loader.dataset, "set_epoch"):
            self.trainer.datamodule.train_loader.dataset.set_epoch(self.current_epoch)
        if hasattr(self.trainer.datamodule.train_loader.sampler, "set_epoch"):
            self.trainer.datamodule.train_loader.sampler.set_epoch(self.current_epoch)

    def on_validation_epoch_start(self) -> None:
        print(f"Validation epoch start on rank {self.trainer.global_rank}")
        # Same as above, for the validation loader.
        if hasattr(self.trainer.datamodule.val_loader.dataset, "set_epoch"):
            self.trainer.datamodule.val_loader.dataset.set_epoch(self.current_epoch)
        if hasattr(self.trainer.datamodule.val_loader.sampler, "set_epoch"):
            self.trainer.datamodule.val_loader.sampler.set_epoch(self.current_epoch)
    def training_step(self, batch, batch_idx):
        # Combine batches from different dataloaders: list-valued entries are
        # concatenated as lists, dict-valued entries tensor-wise along the
        # batch dimension.
        # torch.cuda.empty_cache()
        if isinstance(batch, list):
            batch_combined = None
            for batch_per_dl in batch:
                if batch_combined is None:
                    batch_combined = batch_per_dl
                else:
                    for k in batch_combined.keys():
                        if isinstance(batch_combined[k], list):
                            batch_combined[k] += batch_per_dl[k]
                        elif isinstance(batch_combined[k], dict):
                            for kk in batch_combined[k].keys():
                                batch_combined[k][kk] = torch.cat(
                                    [batch_combined[k][kk], batch_per_dl[k][kk]], dim=0
                                )
                        else:
                            raise NotImplementedError
            batch = batch_combined

        batch: BatchedExample = self.data_shim(batch)
        b, v, c, h, w = batch["context"]["image"].shape

        # Run the model. Context images are mapped from [-1, 1] to [0, 1].
        context_image = (batch["context"]["image"] + 1) / 2
        visualization_dump = None
        encoder_output, output = self.model(
            context_image, self.global_step, visualization_dump=visualization_dump
        )
        gaussians = encoder_output.gaussians
        pred_pose_enc_list = encoder_output.pred_pose_enc_list
        depth_dict = encoder_output.depth_dict
        pred_context_pose = encoder_output.pred_context_pose
        infos = encoder_output.infos
        distill_infos = encoder_output.distill_infos

        num_context_views = pred_context_pose["extrinsic"].shape[1]
        using_index = torch.arange(num_context_views, device=gaussians.means.device)
        batch["using_index"] = using_index

        target_gt = (batch["context"]["image"] + 1) / 2
        scene_scale = infos["scene_scale"]
        self.log("train/scene_scale", infos["scene_scale"])
        self.log("train/voxelize_ratio", infos["voxelize_ratio"])

        # Compute metrics.
        psnr_probabilistic = compute_psnr(
            rearrange(target_gt, "b v c h w -> (b v) c h w"),
            rearrange(output.color, "b v c h w -> (b v) c h w"),
        )
        self.log("train/psnr_probabilistic", psnr_probabilistic.mean())

        # Consistency between rendered depth and the model's predicted depth.
        consis_absrel = abs_relative_difference(
            rearrange(output.depth, "b v h w -> (b v) h w"),
            rearrange(depth_dict["depth"].squeeze(-1), "b v h w -> (b v) h w"),
            rearrange(distill_infos["conf_mask"], "b v h w -> (b v) h w"),
        )
        self.log("train/consis_absrel", consis_absrel.mean())
        consis_delta1 = delta1_acc(
            rearrange(output.depth, "b v h w -> (b v) h w"),
            rearrange(depth_dict["depth"].squeeze(-1), "b v h w -> (b v) h w"),
            rearrange(distill_infos["conf_mask"], "b v h w -> (b v) h w"),
        )
        self.log("train/consis_delta1", consis_delta1.mean())

        # Compute and log loss.
        total_loss = 0
        depth_dict["distill_infos"] = distill_infos
        with torch.amp.autocast("cuda", enabled=False):
            for loss_fn in self.losses:
                loss = loss_fn.forward(output, batch, gaussians, depth_dict, self.global_step)
                self.log(f"loss/{loss_fn.name}", loss)
                total_loss = total_loss + loss

            if (
                depth_dict is not None
                and "depth" in get_cfg()["loss"].keys()
                and self.train_cfg.cxt_depth_weight > 0
            ):
                depth_loss_idx = list(get_cfg()["loss"].keys()).index("depth")
                depth_loss_fn = self.losses[depth_loss_idx].ctx_depth_loss
                loss_depth = depth_loss_fn(
                    depth_dict["depth_map"],
                    depth_dict["depth_conf"],
                    batch,
                    cxt_depth_weight=self.train_cfg.cxt_depth_weight,
                )
                self.log("loss/ctx_depth", loss_depth)
                total_loss = total_loss + loss_depth

            if distill_infos is not None:
                # Distill context pose, depth, and normal predictions.
                loss_distill_list = self.loss_distill(
                    distill_infos, pred_pose_enc_list, output, batch
                )
                self.log("loss/distill", loss_distill_list["loss_distill"])
                self.log("loss/distill_pose", loss_distill_list["loss_pose"])
                self.log("loss/distill_depth", loss_distill_list["loss_depth"])
                self.log("loss/distill_normal", loss_distill_list["loss_normal"])
                total_loss = total_loss + loss_distill_list["loss_distill"]

        self.log("loss/total", total_loss)
        print(f"total_loss: {total_loss}")

        # Skip batches whose loss is anomalously high after warm-up.
        SKIP_AFTER_STEP = 1000
        LOSS_THRESHOLD = 0.2
        if self.global_step > SKIP_AFTER_STEP and total_loss > LOSS_THRESHOLD:
            print(
                f"Skipping batch with high loss ({total_loss:.6f}) "
                f"at step {self.global_step} on rank {self.global_rank}"
            )
            # Scale the loss down to a negligible value so the gradient step
            # is effectively a no-op while keeping the graph intact.
            return total_loss * 1e-10

        if (
            self.global_rank == 0
            and self.global_step % self.train_cfg.print_log_every_n_steps == 0
        ):
            print(
                f"train step {self.global_step}; "
                f"scene = {[x[:20] for x in batch['scene']]}; "
                f"context = {batch['context']['index'].tolist()}; "
                f"loss = {total_loss:.6f}"
            )
        self.log("info/global_step", self.global_step)  # hack for ckpt monitor

        # Tell the data loader processes about the current step.
        if self.step_tracker is not None:
            self.step_tracker.set_step(self.global_step)

        del batch
        if self.global_step % 50 == 0:
            gc.collect()
            torch.cuda.empty_cache()
        return total_loss
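    # Schematically, the total loss assembled in training_step is (a sketch;
    # terms appear only when their branches above are active):
    #     L_total = sum_i L_i(render, batch)        # configured self.losses
    #             + cxt_depth_weight * L_ctx_depth  # context-depth supervision
    #             + L_distill                       # pose/depth/normal distillation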
    def on_after_backward(self):
        # Log the root-mean-square per-parameter gradient norm.
        total_norm = 0.0
        counter = 0
        for p in self.parameters():
            if p.grad is not None:
                param_norm = p.grad.detach().data.norm(2)
                total_norm += param_norm.item() ** 2
                counter += 1
        if counter > 0:  # guard against division by zero when no grads exist
            total_norm = (total_norm / counter) ** 0.5
            self.log("loss/grad_norm", total_norm)
    def test_step(self, batch, batch_idx):
        batch: BatchedExample = self.data_shim(batch)
        b, v, _, h, w = batch["target"]["image"].shape
        assert b == 1
        if batch_idx % 100 == 0:
            print(f"Test step {batch_idx:0>6}.")

        # Run the encoder to predict Gaussians.
        with self.benchmarker.time("encoder"):
            gaussians = self.model.encoder(
                (batch["context"]["image"] + 1) / 2,
                self.global_step,
            )[0]
        # export_ply(gaussians.means[0], gaussians.scales[0], gaussians.rotations[0], gaussians.harmonics[0], gaussians.opacities[0], Path("gaussians.ply"))

        # Render the target views, optionally aligning the target poses first.
        if self.test_cfg.align_pose:
            output = self.test_step_align(batch, gaussians)
        else:
            with self.benchmarker.time("decoder", num_calls=v):
                output = self.model.decoder.forward(
                    gaussians,
                    batch["target"]["extrinsics"],
                    batch["target"]["intrinsics"],
                    batch["target"]["near"],
                    batch["target"]["far"],
                    (h, w),
                )

        # Defined outside the scoring branch since save_compare also needs them.
        rgb_pred = output.color[0]
        rgb_gt = batch["target"]["image"][0]

        # Compute scores.
        if self.test_cfg.compute_scores:
            overlap = batch["context"]["overlap"][0]
            overlap_tag = get_overlap_tag(overlap)
            all_metrics = {
                "lpips_ours": compute_lpips(rgb_gt, rgb_pred).mean(),
                "ssim_ours": compute_ssim(rgb_gt, rgb_pred).mean(),
                "psnr_ours": compute_psnr(rgb_gt, rgb_pred).mean(),
            }
            methods = ["ours"]
            self.log_dict(all_metrics)
            self.print_preview_metrics(all_metrics, methods, overlap_tag=overlap_tag)

        # Save images.
        (scene,) = batch["scene"]
        name = get_cfg()["wandb"]["name"]
        path = self.test_cfg.output_path / name
        if self.test_cfg.save_image:
            for index, color in zip(batch["target"]["index"][0], output.color[0]):
                save_image(color, path / scene / f"color/{index:0>6}.png")
        if self.test_cfg.save_video:
            frame_str = "_".join([str(x.item()) for x in batch["context"]["index"][0]])
            save_video(
                [a for a in output.color[0]],
                path / "video" / f"{scene}_frame_{frame_str}.mp4",
            )
        if self.test_cfg.save_compare:
            # Construct a side-by-side comparison image.
            context_img = inverse_normalize(batch["context"]["image"][0])
            comparison = hcat(
                add_label(vcat(*context_img), "Context"),
                add_label(vcat(*rgb_gt), "Target (Ground Truth)"),
                add_label(vcat(*rgb_pred), "Target (Prediction)"),
            )
            save_image(comparison, path / f"{scene}.png")
    def test_step_align(self, batch, gaussians):
        # Freeze the encoder and optimize a small per-view pose delta so the
        # rendered views line up with the ground-truth targets.
        self.model.encoder.eval()
        for param in self.model.encoder.parameters():
            param.requires_grad = False

        b, v, _, h, w = batch["target"]["image"].shape
        output_c2ws = batch["target"]["extrinsics"]
        with torch.set_grad_enabled(True):
            cam_rot_delta = nn.Parameter(
                torch.zeros([b, v, 6], requires_grad=True, device=output_c2ws.device)
            )
            cam_trans_delta = nn.Parameter(
                torch.zeros([b, v, 3], requires_grad=True, device=output_c2ws.device)
            )
            # Identity rotation in the 6D rotation representation.
            self.register_buffer(
                "identity",
                torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]).to(output_c2ws),
            )
            opt_params = [
                {"params": [cam_rot_delta], "lr": 0.005},
                {"params": [cam_trans_delta], "lr": 0.005},
            ]
            pose_optimizer = torch.optim.Adam(opt_params)
            extrinsics = output_c2ws.clone()

            with self.benchmarker.time("optimize"):
                for i in range(self.test_cfg.pose_align_steps):
                    pose_optimizer.zero_grad()

                    # Build the SE(3) correction from the current deltas.
                    dx, drot = cam_trans_delta, cam_rot_delta
                    rot = rotation_6d_to_matrix(
                        drot + self.identity.expand(b, v, -1)
                    )  # (..., 3, 3)
                    transform = torch.eye(4, device=extrinsics.device).repeat((b, v, 1, 1))
                    transform[..., :3, :3] = rot
                    transform[..., :3, 3] = dx
                    new_extrinsics = torch.matmul(extrinsics, transform)

                    output = self.model.decoder.forward(
                        gaussians,
                        new_extrinsics,
                        batch["target"]["intrinsics"],
                        batch["target"]["near"],
                        batch["target"]["far"],
                        (h, w),
                        # cam_rot_delta=cam_rot_delta,
                        # cam_trans_delta=cam_trans_delta,
                    )

                    # Compute the loss and take an optimization step.
                    total_loss = 0
                    for loss_fn in self.losses:
                        loss = loss_fn.forward(output, batch, gaussians, self.global_step)
                        total_loss = total_loss + loss
                    total_loss.backward()
                    pose_optimizer.step()

        # Render the Gaussians with the aligned poses.
        output = self.model.decoder.forward(
            gaussians,
            new_extrinsics,
            batch["target"]["intrinsics"],
            batch["target"]["near"],
            batch["target"]["far"],
            (h, w),
        )
        return output
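    # For reference, the per-view update optimized above is (a sketch of the
    # math, not original code): with d_rot the 6D rotation delta and d_x the
    # translation delta,
    #     T_delta = [[R(d_rot + I_6d), d_x], [0, 1]]
    #     E_new   = E @ T_delta
    # where R(.) orthogonalizes the 6D representation into a rotation matrix
    # (Zhou et al., "On the Continuity of Rotation Representations in Neural
    # Networks") and I_6d = (1, 0, 0, 0, 1, 0) encodes the identity rotation.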
    def on_test_end(self) -> None:
        name = get_cfg()["wandb"]["name"]
        self.benchmarker.dump(self.test_cfg.output_path / name / "benchmark.json")
        self.benchmarker.dump_memory(
            self.test_cfg.output_path / name / "peak_memory.json"
        )
        self.benchmarker.summarize()
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        batch: BatchedExample = self.data_shim(batch)
        if self.global_rank == 0:
            print(
                f"validation step {self.global_step}; "
                f"scene = {batch['scene']}; "
                f"context = {batch['context']['index'].tolist()}"
            )

        # Render Gaussians.
        b, v, _, h, w = batch["context"]["image"].shape
        assert b == 1
        visualization_dump = {}
        encoder_output, output = self.model(
            batch["context"]["image"], self.global_step, visualization_dump=visualization_dump
        )
        gaussians = encoder_output.gaussians
        pred_pose_enc_list = encoder_output.pred_pose_enc_list
        depth_dict = encoder_output.depth_dict
        pred_context_pose = encoder_output.pred_context_pose
        distill_infos = encoder_output.distill_infos
        infos = encoder_output.infos

        GS_num = infos["voxelize_ratio"] * (h * w * v)
        self.log("val/GS_num", GS_num)
        num_context_views = pred_context_pose["extrinsic"].shape[1]
        num_target_views = batch["target"]["extrinsics"].shape[1]

        rgb_pred = output.color[0].float()
        depth_pred = vis_depth_map(output.depth[0])

        # Direct depth from Gaussian means (used for visualization only).
        gaussian_means = visualization_dump["depth"][0].squeeze()
        if gaussian_means.shape[-1] == 3:
            gaussian_means = gaussian_means.mean(dim=-1)

        # Compute validation metrics.
        rgb_gt = (batch["context"]["image"][0].float() + 1) / 2
        psnr = compute_psnr(rgb_gt, rgb_pred).mean()
        self.log("val/psnr", psnr)
        lpips = compute_lpips(rgb_gt, rgb_pred).mean()
        self.log("val/lpips", lpips)
        ssim = compute_ssim(rgb_gt, rgb_pred).mean()
        self.log("val/ssim", ssim)

        # Depth-consistency metrics between rendered and predicted depth.
        consis_absrel = abs_relative_difference(
            rearrange(output.depth, "b v h w -> (b v) h w"),
            rearrange(depth_dict["depth"].squeeze(-1), "b v h w -> (b v) h w"),
        )
        self.log("val/consis_absrel", consis_absrel.mean())
        consis_delta1 = delta1_acc(
            rearrange(output.depth, "b v h w -> (b v) h w"),
            rearrange(depth_dict["depth"].squeeze(-1), "b v h w -> (b v) h w"),
            valid_mask=rearrange(
                torch.ones_like(output.depth, device=output.depth.device, dtype=torch.bool),
                "b v h w -> (b v) h w",
            ),
        )
        self.log("val/consis_delta1", consis_delta1.mean())
        diff_map = torch.abs(output.depth - depth_dict["depth"].squeeze(-1))
        self.log("val/consis_mse", diff_map[distill_infos["conf_mask"]].mean())

        # Construct comparison image.
        context_img = inverse_normalize(batch["context"]["image"][0])
        # context_img_depth = vis_depth_map(gaussian_means)
        context = []
        for i in range(context_img.shape[0]):
            context.append(context_img[i])
            # context.append(context_img_depth[i])
        colored_diff_map = vis_depth_map(
            diff_map[0],
            near=torch.tensor(1e-4, device=diff_map.device),
            far=torch.tensor(1.0, device=diff_map.device),
        )
        model_depth_pred = vis_depth_map(depth_dict["depth"].squeeze(-1)[0])
        render_normal = (
            get_normal_map(
                output.depth.flatten(0, 1),
                batch["context"]["intrinsics"].flatten(0, 1),
            ).permute(0, 3, 1, 2)
            + 1
        ) / 2.0
        pred_normal = (
            get_normal_map(
                depth_dict["depth"].flatten(0, 1).squeeze(-1),
                batch["context"]["intrinsics"].flatten(0, 1),
            ).permute(0, 3, 1, 2)
            + 1
        ) / 2.0
        comparison = hcat(
            add_label(vcat(*context), "Context"),
            add_label(vcat(*rgb_gt), "Target (Ground Truth)"),
            add_label(vcat(*rgb_pred), "Target (Prediction)"),
            add_label(vcat(*depth_pred), "Depth (Prediction)"),
            add_label(vcat(*model_depth_pred), "Depth (VGGT Prediction)"),
            add_label(vcat(*render_normal), "Normal (Prediction)"),
            add_label(vcat(*pred_normal), "Normal (VGGT Prediction)"),
            add_label(vcat(*colored_diff_map), "Diff Map"),
        )
        comparison = torch.nn.functional.interpolate(
            comparison.unsqueeze(0),
            scale_factor=0.5,
            mode="bicubic",
            align_corners=False,
        ).squeeze(0)
        self.logger.log_image(
            "comparison",
            [prep_image(add_border(comparison))],
            step=self.global_step,
            caption=batch["scene"],
        )
        # self.logger.log_image(
        #     key="comparison",
        #     images=[wandb.Image(prep_image(add_border(comparison)), caption=batch["scene"], file_type="jpg")],
        #     step=self.global_step,
        # )

        # Render projections and construct projection image.
        # These are disabled for now, since RE10k scenes are effectively unbounded.
        # if isinstance(gaussians, Gaussians):
        #     projections = hcat(
        #         *render_projections(
        #             gaussians,
        #             256,
        #             extra_label="",
        #         )[0]
        #     )
        #     self.logger.log_image(
        #         "projection",
        #         [prep_image(add_border(projections))],
        #         step=self.global_step,
        #     )

        # Draw cameras.
        # cameras = hcat(*render_cameras(batch, 256))
        # self.logger.log_image(
        #     "cameras", [prep_image(add_border(cameras))], step=self.global_step
        # )

        if self.encoder_visualizer is not None:
            for k, image in self.encoder_visualizer.visualize(
                batch["context"], self.global_step
            ).items():
                self.logger.log_image(k, [prep_image(image)], step=self.global_step)

        # Run video validation steps.
        self.render_video_interpolation(batch)
        self.render_video_wobble(batch)
        if self.train_cfg.extended_visualization:
            self.render_video_interpolation_exaggerated(batch)
    def render_video_wobble(self, batch: BatchedExample) -> None:
        # Two views are needed to get the wobble radius.
        _, v, _, _ = batch["context"]["extrinsics"].shape
        if v != 2:
            return

        def trajectory_fn(t):
            origin_a = batch["context"]["extrinsics"][:, 0, :3, 3]
            origin_b = batch["context"]["extrinsics"][:, 1, :3, 3]
            delta = (origin_a - origin_b).norm(dim=-1)
            extrinsics = generate_wobble(
                batch["context"]["extrinsics"][:, 0],
                delta * 0.25,
                t,
            )
            intrinsics = repeat(
                batch["context"]["intrinsics"][:, 0],
                "b i j -> b v i j",
                v=t.shape[0],
            )
            return extrinsics, intrinsics

        return self.render_video_generic(batch, trajectory_fn, "wobble", num_frames=60)
    def render_video_interpolation(self, batch: BatchedExample) -> None:
        _, v, _, _ = batch["context"]["extrinsics"].shape

        def trajectory_fn(t):
            extrinsics = interpolate_extrinsics(
                batch["context"]["extrinsics"][0, 0],
                (
                    batch["context"]["extrinsics"][0, 1]
                    if v == 2
                    else batch["target"]["extrinsics"][0, 0]
                ),
                t,
            )
            intrinsics = interpolate_intrinsics(
                batch["context"]["intrinsics"][0, 0],
                (
                    batch["context"]["intrinsics"][0, 1]
                    if v == 2
                    else batch["target"]["intrinsics"][0, 0]
                ),
                t,
            )
            return extrinsics[None], intrinsics[None]

        return self.render_video_generic(batch, trajectory_fn, "rgb")
    def render_video_interpolation_exaggerated(self, batch: BatchedExample) -> None:
        # Two views are needed to get the wobble radius.
        _, v, _, _ = batch["context"]["extrinsics"].shape
        if v != 2:
            return

        def trajectory_fn(t):
            origin_a = batch["context"]["extrinsics"][:, 0, :3, 3]
            origin_b = batch["context"]["extrinsics"][:, 1, :3, 3]
            delta = (origin_a - origin_b).norm(dim=-1)
            tf = generate_wobble_transformation(
                delta * 0.5,
                t,
                5,
                scale_radius_with_t=False,
            )
            extrinsics = interpolate_extrinsics(
                batch["context"]["extrinsics"][0, 0],
                (
                    batch["context"]["extrinsics"][0, 1]
                    if v == 2
                    else batch["target"]["extrinsics"][0, 0]
                ),
                t * 5 - 2,
            )
            intrinsics = interpolate_intrinsics(
                batch["context"]["intrinsics"][0, 0],
                (
                    batch["context"]["intrinsics"][0, 1]
                    if v == 2
                    else batch["target"]["intrinsics"][0, 0]
                ),
                t * 5 - 2,
            )
            return extrinsics @ tf, intrinsics[None]

        return self.render_video_generic(
            batch,
            trajectory_fn,
            "interpolation_exaggerated",
            num_frames=300,
            smooth=False,
            loop_reverse=False,
        )
    def render_video_generic(
        self,
        batch: BatchedExample,
        trajectory_fn: TrajectoryFn,
        name: str,
        num_frames: int = 30,
        smooth: bool = True,
        loop_reverse: bool = True,
    ) -> None:
        # Render a probabilistic estimate of the scene.
        encoder_output = self.model.encoder(
            (batch["context"]["image"] + 1) / 2, self.global_step
        )
        gaussians = encoder_output.gaussians
        pred_pose_enc_list = encoder_output.pred_pose_enc_list

        t = torch.linspace(0, 1, num_frames, dtype=torch.float32, device=self.device)
        if smooth:
            # Ease in and out with a raised-cosine ramp.
            t = (torch.cos(torch.pi * (t + 1)) + 1) / 2
        extrinsics, intrinsics = trajectory_fn(t)

        _, _, _, h, w = batch["context"]["image"].shape

        # TODO: Interpolate near and far planes?
        near = repeat(batch["context"]["near"][:, 0], "b -> b v", v=num_frames)
        far = repeat(batch["context"]["far"][:, 0], "b -> b v", v=num_frames)
        output = self.model.decoder.forward(
            gaussians, extrinsics, intrinsics, near, far, (h, w), "depth"
        )
        images = [
            vcat(rgb, depth)
            for rgb, depth in zip(output.color[0], vis_depth_map(output.depth[0]))
        ]

        video = torch.stack(images)
        video = (video.clip(min=0, max=1) * 255).type(torch.uint8).cpu().numpy()
        if loop_reverse:
            video = pack([video, video[::-1][1:-1]], "* c h w")[0]
        visualizations = {
            f"video/{name}": wandb.Video(video[None], fps=30, format="mp4")
        }

        # Since PyTorch Lightning doesn't support video logging, log to wandb directly.
        try:
            wandb.log(visualizations)
        except Exception:
            assert isinstance(self.logger, LocalLogger)
            for key, value in visualizations.items():
                tensor = value._prepare_video(value.data)
                clip = mpy.ImageSequenceClip(list(tensor), fps=30)
                save_dir = LOG_PATH / key
                save_dir.mkdir(exist_ok=True, parents=True)
                clip.write_videofile(
                    str(save_dir / f"{self.global_step:0>6}.mp4"), logger=None
                )
    def print_preview_metrics(
        self,
        metrics: dict[str, float | Tensor],
        methods: list[str] | None = None,
        overlap_tag: str | None = None,
    ) -> None:
        # Maintain running averages of the metrics across test steps, both
        # overall and per overlap bucket.
        if getattr(self, "running_metrics", None) is None:
            self.running_metrics = metrics
            self.running_metric_steps = 1
        else:
            s = self.running_metric_steps
            self.running_metrics = {
                k: ((s * v) + metrics[k]) / (s + 1)
                for k, v in self.running_metrics.items()
            }
            self.running_metric_steps += 1

        if overlap_tag is not None:
            if getattr(self, "running_metrics_sub", None) is None:
                self.running_metrics_sub = {overlap_tag: metrics}
                self.running_metric_steps_sub = {overlap_tag: 1}
            elif overlap_tag not in self.running_metrics_sub:
                self.running_metrics_sub[overlap_tag] = metrics
                self.running_metric_steps_sub[overlap_tag] = 1
            else:
                s = self.running_metric_steps_sub[overlap_tag]
                self.running_metrics_sub[overlap_tag] = {
                    k: ((s * v) + metrics[k]) / (s + 1)
                    for k, v in self.running_metrics_sub[overlap_tag].items()
                }
                self.running_metric_steps_sub[overlap_tag] += 1

        metric_list = ["psnr", "lpips", "ssim"]

        def print_metrics(running_metric, methods=None):
            table = []
            if methods is None:
                methods = ["ours"]
            for method in methods:
                row = [
                    f"{running_metric[f'{metric}_{method}']:.3f}"
                    for metric in metric_list
                ]
                table.append((method, *row))
            headers = ["Method"] + metric_list
            table = tabulate(table, headers)
            print(table)

        print("All Pairs:")
        print_metrics(self.running_metrics, methods)
        if overlap_tag is not None:
            for k, v in self.running_metrics_sub.items():
                print(f"Overlap: {k}")
                print_metrics(v, methods)
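    # Example of the preview printout produced above via tabulate's default
    # "simple" format (values are illustrative):
    #
    #   All Pairs:
    #   Method      psnr    lpips    ssim
    #   --------  ------  -------  ------
    #   ours      24.312   0.142   0.801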
    def configure_optimizers(self):
        # Split parameters into newly initialized heads (full LR) and
        # pretrained backbone weights (scaled LR).
        new_params, new_param_names = [], []
        pretrained_params, pretrained_param_names = [], []
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            if "gaussian_param_head" in name or "interm" in name:
                new_params.append(param)
                new_param_names.append(name)
            else:
                pretrained_params.append(param)
                pretrained_param_names.append(name)

        param_dicts = [
            {
                "params": new_params,
                "lr": self.optimizer_cfg.lr,
            },
            {
                "params": pretrained_params,
                "lr": self.optimizer_cfg.lr * self.optimizer_cfg.backbone_lr_multiplier,
            },
        ]
        optimizer = torch.optim.AdamW(
            param_dicts, lr=self.optimizer_cfg.lr, weight_decay=0.05, betas=(0.9, 0.95)
        )

        # Linear warm-up followed by cosine annealing, stepped per iteration.
        warm_up_steps = self.optimizer_cfg.warm_up_steps
        warm_up = torch.optim.lr_scheduler.LinearLR(
            optimizer,
            1 / warm_up_steps,
            1,
            total_iters=warm_up_steps,
        )
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=get_cfg()["trainer"]["max_steps"],
            eta_min=self.optimizer_cfg.lr * 0.1,
        )
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warm_up, lr_scheduler], milestones=[warm_up_steps]
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }
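# A minimal standalone sketch of the schedule configured above (illustrative,
# with made-up hyperparameters): linear warm-up from lr/warm_up_steps to lr,
# then cosine annealing down to 10% of the base LR by max_steps.
#
#     opt = torch.optim.AdamW([nn.Parameter(torch.zeros(1))], lr=2e-4)
#     warm = torch.optim.lr_scheduler.LinearLR(opt, 1 / 100, 1.0, total_iters=100)
#     cosine = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10_000, eta_min=2e-5)
#     sched = torch.optim.lr_scheduler.SequentialLR(opt, [warm, cosine], milestones=[100])
#     for step in range(10_000):
#         opt.step()
#         sched.step()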