| | from dataclasses import dataclass |
| | from pathlib import Path |
| | import gc |
| | import random |
| | from typing import Literal, Optional, Protocol, runtime_checkable, Any |
| |
|
| | import moviepy.editor as mpy |
| | import torch |
| | import torchvision |
| | import wandb |
| | from einops import pack, rearrange, repeat |
| | from jaxtyping import Float |
| | from lightning.pytorch import LightningModule |
| | from lightning.pytorch.loggers.wandb import WandbLogger |
| | from lightning.pytorch.utilities import rank_zero_only |
| | from tabulate import tabulate |
| | from torch import Tensor, nn, optim |
| | import torch.nn.functional as F |
| |
|
| | from loss.loss_lpips import LossLpips |
| | from loss.loss_mse import LossMse |
| | from model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri |
| |
|
| | from ..loss.loss_distill import DistillLoss |
| | from src.utils.render import generate_path |
| | from src.utils.point import get_normal_map |
| |
|
| | from ..loss.loss_huber import HuberLoss, extri_intri_to_pose_encoding |
| |
|
| | |
| |
|
| | from ..dataset.data_module import get_data_shim |
| | from ..dataset.types import BatchedExample |
| | from ..evaluation.metrics import compute_lpips, compute_psnr, compute_ssim, abs_relative_difference, delta1_acc |
| | from ..global_cfg import get_cfg |
| | from ..loss import Loss |
| | from ..loss.loss_point import Regr3D |
| | from ..loss.loss_ssim import ssim |
| | from ..misc.benchmarker import Benchmarker |
| | from ..misc.cam_utils import update_pose, get_pnp_pose, rotation_6d_to_matrix |
| | from ..misc.image_io import prep_image, save_image, save_video |
| | from ..misc.LocalLogger import LOG_PATH, LocalLogger |
| | from ..misc.nn_module_tools import convert_to_buffer |
| | from ..misc.step_tracker import StepTracker |
| | from ..misc.utils import inverse_normalize, vis_depth_map, confidence_map, get_overlap_tag |
| | from ..visualization.annotation import add_label |
| | from ..visualization.camera_trajectory.interpolation import ( |
| | interpolate_extrinsics, |
| | interpolate_intrinsics, |
| | ) |
| | from ..visualization.camera_trajectory.wobble import ( |
| | generate_wobble, |
| | generate_wobble_transformation, |
| | ) |
| | from ..visualization.color_map import apply_color_map_to_image |
| | from ..visualization.layout import add_border, hcat, vcat |
| | |
| | from .decoder.decoder import Decoder, DepthRenderingMode |
| | from .encoder import Encoder |
| | from .encoder.visualization.encoder_visualizer import EncoderVisualizer |
| | from .ply_export import export_ply |
| |
|
@dataclass
class OptimizerCfg:
    """Hyper-parameters for the AdamW optimizer and its LR schedule."""

    lr: float  # base learning rate for newly-initialized heads
    warm_up_steps: int  # length of the linear warm-up phase, in optimizer steps
    backbone_lr_multiplier: float  # scales `lr` for pretrained (backbone) parameters
| |
|
| |
|
@dataclass
class TestCfg:
    """Options controlling evaluation/inference in `test_step`."""

    output_path: Path  # root directory for metrics, images, and videos
    align_pose: bool  # optimize per-view target-pose deltas before rendering
    pose_align_steps: int  # Adam iterations used for pose alignment
    rot_opt_lr: float  # rotation-delta LR — NOTE(review): test_step_align hard-codes 0.005; verify
    trans_opt_lr: float  # translation-delta LR — NOTE(review): test_step_align hard-codes 0.005; verify
    compute_scores: bool  # compute PSNR/SSIM/LPIPS against targets
    save_image: bool  # save per-view rendered images
    save_video: bool  # save rendered target-view videos
    save_compare: bool  # save context/GT/prediction comparison sheets
    generate_video: bool  # not referenced in this file — confirm downstream use
    mode: Literal["inference", "evaluation"]  # run mode
    image_folder: str  # not referenced in this file — confirm downstream use
| |
|
| |
|
@dataclass
class TrainCfg:
    """Training-loop options and auxiliary-loss weights."""

    output_path: Path  # directory for training artifacts
    depth_mode: DepthRenderingMode | None  # depth rendering mode, if any
    extended_visualization: bool  # also render the exaggerated interpolation video
    print_log_every_n_steps: int  # console-log cadence (rank 0 only)
    distiller: str  # distillation teacher identifier — not read in this file; verify
    distill_max_steps: int  # distillation step budget — not read in this file; verify
    pose_loss_alpha: float = 1.0  # alpha for the Huber pose loss
    pose_loss_delta: float = 1.0  # delta shared by the Huber and distill losses
    cxt_depth_weight: float = 0.01  # weight for context-depth supervision (0 disables)
    weight_pose: float = 1.0  # distillation pose-term weight
    weight_depth: float = 1.0  # distillation depth-term weight
    weight_normal: float = 1.0  # distillation normal-term weight
    render_ba: bool = False  # not read in this file — confirm downstream use
    render_ba_after_step: int = 0  # not read in this file — confirm downstream use
| |
|
| |
|
@runtime_checkable
class TrajectoryFn(Protocol):
    """Callable mapping trajectory parameters t in [0, 1] to camera poses.

    Returns a (extrinsics, intrinsics) pair with one pose per sample of t.
    """

    def __call__(
        self,
        t: Float[Tensor, " t"],
    ) -> tuple[
        Float[Tensor, "batch view 4 4"],  # extrinsics
        Float[Tensor, "batch view 3 3"],  # intrinsics
    ]:
        pass
| |
|
| |
|
class ModelWrapper(LightningModule):
    """Lightning wrapper around the encoder/decoder model.

    Drives training, validation, testing (with optional pose alignment),
    qualitative video rendering, and optimizer/scheduler configuration.
    """

    logger: Optional[WandbLogger]
    model: nn.Module
    losses: nn.ModuleList
    optimizer_cfg: OptimizerCfg
    test_cfg: TestCfg
    train_cfg: TrainCfg
    step_tracker: StepTracker | None
|
| | def __init__( |
| | self, |
| | optimizer_cfg: OptimizerCfg, |
| | test_cfg: TestCfg, |
| | train_cfg: TrainCfg, |
| | model: nn.Module, |
| | losses: list[Loss], |
| | step_tracker: StepTracker | None |
| | ) -> None: |
| | super().__init__() |
| | self.optimizer_cfg = optimizer_cfg |
| | self.test_cfg = test_cfg |
| | self.train_cfg = train_cfg |
| | self.step_tracker = step_tracker |
| | |
| | |
| | self.encoder_visualizer = None |
| | self.model = model |
| | self.data_shim = get_data_shim(self.model.encoder) |
| | self.losses = nn.ModuleList(losses) |
| | |
| | if self.model.encoder.pred_pose: |
| | self.loss_pose = HuberLoss(alpha=self.train_cfg.pose_loss_alpha, delta=self.train_cfg.pose_loss_delta) |
| | |
| | if self.model.encoder.distill: |
| | self.loss_distill = DistillLoss( |
| | delta=self.train_cfg.pose_loss_delta, |
| | weight_pose=self.train_cfg.weight_pose, |
| | weight_depth=self.train_cfg.weight_depth, |
| | weight_normal=self.train_cfg.weight_normal |
| | ) |
| |
|
| | |
| | self.benchmarker = Benchmarker() |
| | |
| | def on_train_epoch_start(self) -> None: |
| | |
| | if hasattr(self.trainer.datamodule.train_loader.dataset, "set_epoch"): |
| | self.trainer.datamodule.train_loader.dataset.set_epoch(self.current_epoch) |
| | if hasattr(self.trainer.datamodule.train_loader.sampler, "set_epoch"): |
| | self.trainer.datamodule.train_loader.sampler.set_epoch(self.current_epoch) |
| |
|
| | def on_validation_epoch_start(self) -> None: |
| | print(f"Validation epoch start on rank {self.trainer.global_rank}") |
| | |
| | if hasattr(self.trainer.datamodule.val_loader.dataset, "set_epoch"): |
| | self.trainer.datamodule.val_loader.dataset.set_epoch(self.current_epoch) |
| | if hasattr(self.trainer.datamodule.val_loader.sampler, "set_epoch"): |
| | self.trainer.datamodule.val_loader.sampler.set_epoch(self.current_epoch) |
| | |
    def training_step(self, batch, batch_idx):
        """Run one training step: forward, metric logging, loss aggregation.

        Returns the total loss (scaled to ~0 for suspected-bad batches so the
        optimizer step is effectively skipped without breaking DDP sync).
        """
        # When several dataloaders are combined, merge their per-loader batches
        # into one: list entries are concatenated, dict-of-tensor entries are
        # concatenated along the batch dimension.
        if isinstance(batch, list):
            batch_combined = None
            for batch_per_dl in batch:
                if batch_combined is None:
                    batch_combined = batch_per_dl
                else:
                    for k in batch_combined.keys():
                        if isinstance(batch_combined[k], list):
                            batch_combined[k] += batch_per_dl[k]
                        elif isinstance(batch_combined[k], dict):
                            for kk in batch_combined[k].keys():
                                batch_combined[k][kk] = torch.cat([batch_combined[k][kk], batch_per_dl[k][kk]], dim=0)
                        else:
                            raise NotImplementedError
            batch = batch_combined

        batch: BatchedExample = self.data_shim(batch)
        b, v, c, h, w = batch["context"]["image"].shape
        # Context images are stored in [-1, 1]; the model expects [0, 1].
        context_image = (batch["context"]["image"] + 1) / 2

        visualization_dump = None

        # Forward pass: encoder predictions plus rendered output.
        encoder_output, output = self.model(context_image, self.global_step, visualization_dump=visualization_dump)
        gaussians, pred_pose_enc_list, depth_dict = encoder_output.gaussians, encoder_output.pred_pose_enc_list, encoder_output.depth_dict
        pred_context_pose = encoder_output.pred_context_pose
        infos = encoder_output.infos
        distill_infos = encoder_output.distill_infos

        num_context_views = pred_context_pose['extrinsic'].shape[1]

        # All context views participate in supervision.
        using_index = torch.arange(num_context_views, device=gaussians.means.device)
        batch["using_index"] = using_index

        # Supervision target: context images mapped back to [0, 1].
        target_gt = (batch["context"]["image"] + 1) / 2
        scene_scale = infos["scene_scale"]
        self.log("train/scene_scale", infos["scene_scale"])
        self.log("train/voxelize_ratio", infos["voxelize_ratio"])

        # Monitoring: PSNR between rendered and ground-truth context images.
        psnr_probabilistic = compute_psnr(
            rearrange(target_gt, "b v c h w -> (b v) c h w"),
            rearrange(output.color, "b v c h w -> (b v) c h w"),
        )
        self.log("train/psnr_probabilistic", psnr_probabilistic.mean())

        # Consistency between rendered depth and the model-predicted depth,
        # restricted to the confident region.
        consis_absrel = abs_relative_difference(
            rearrange(output.depth, "b v h w -> (b v) h w"),
            rearrange(depth_dict['depth'].squeeze(-1), "b v h w -> (b v) h w"),
            rearrange(distill_infos['conf_mask'], "b v h w -> (b v) h w"),
        )
        self.log("train/consis_absrel", consis_absrel.mean())

        consis_delta1 = delta1_acc(
            rearrange(output.depth, "b v h w -> (b v) h w"),
            rearrange(depth_dict['depth'].squeeze(-1), "b v h w -> (b v) h w"),
            rearrange(distill_infos['conf_mask'], "b v h w -> (b v) h w"),
        )
        self.log("train/consis_delta1", consis_delta1.mean())

        total_loss = 0

        # Main reconstruction losses, computed in full precision.
        depth_dict['distill_infos'] = distill_infos
        with torch.amp.autocast('cuda', enabled=False):
            for loss_fn in self.losses:
                loss = loss_fn.forward(output, batch, gaussians, depth_dict, self.global_step)
                self.log(f"loss/{loss_fn.name}", loss)
                total_loss = total_loss + loss

        # Optional context-depth supervision via the configured depth loss.
        if depth_dict is not None and "depth" in get_cfg()["loss"].keys() and self.train_cfg.cxt_depth_weight > 0:
            depth_loss_idx = list(get_cfg()["loss"].keys()).index("depth")
            depth_loss_fn = self.losses[depth_loss_idx].ctx_depth_loss
            loss_depth = depth_loss_fn(depth_dict["depth_map"], depth_dict["depth_conf"], batch, cxt_depth_weight=self.train_cfg.cxt_depth_weight)
            self.log("loss/ctx_depth", loss_depth)
            total_loss = total_loss + loss_depth

        # Optional distillation losses (pose / depth / normal terms).
        if distill_infos is not None:
            loss_distill_list = self.loss_distill(distill_infos, pred_pose_enc_list, output, batch)
            self.log("loss/distill", loss_distill_list['loss_distill'])
            self.log("loss/distill_pose", loss_distill_list['loss_pose'])
            self.log("loss/distill_depth", loss_distill_list['loss_depth'])
            self.log("loss/distill_normal", loss_distill_list['loss_normal'])
            total_loss = total_loss + loss_distill_list['loss_distill']

        self.log("loss/total", total_loss)
        print(f"total_loss: {total_loss}")  # per-step console trace

        # After warm-up, effectively skip batches with abnormally high loss by
        # scaling the loss to ~0 (keeps the graph alive for DDP).
        SKIP_AFTER_STEP = 1000
        LOSS_THRESHOLD = 0.2
        if self.global_step > SKIP_AFTER_STEP and total_loss > LOSS_THRESHOLD:
            print(f"Skipping batch with high loss ({total_loss:.6f}) at step {self.global_step} on Rank {self.global_rank}")

            return total_loss * 1e-10

        if (
            self.global_rank == 0
            and self.global_step % self.train_cfg.print_log_every_n_steps == 0
        ):
            print(
                f"train step {self.global_step}; "
                f"scene = {[x[:20] for x in batch['scene']]}; "
                f"context = {batch['context']['index'].tolist()}; "
                f"loss = {total_loss:.6f}; "
            )

        self.log("info/global_step", self.global_step)

        # Keep the shared step tracker (used by datasets/curriculum) in sync.
        if self.step_tracker is not None:
            self.step_tracker.set_step(self.global_step)

        # Periodically release cached memory.
        del batch
        if self.global_step % 50 == 0:
            gc.collect()
            torch.cuda.empty_cache()

        return total_loss
| | |
| | def on_after_backward(self): |
| | total_norm = 0.0 |
| | counter = 0 |
| | for p in self.parameters(): |
| | if p.grad is not None: |
| | param_norm = p.grad.detach().data.norm(2) |
| | total_norm += param_norm.item() ** 2 |
| | counter += 1 |
| | total_norm = (total_norm / counter) ** 0.5 |
| | self.log("loss/grad_norm", total_norm) |
| | |
| | def test_step(self, batch, batch_idx): |
| | batch: BatchedExample = self.data_shim(batch) |
| | b, v, _, h, w = batch["target"]["image"].shape |
| | assert b == 1 |
| | if batch_idx % 100 == 0: |
| | print(f"Test step {batch_idx:0>6}.") |
| | |
| | |
| | with self.benchmarker.time("encoder"): |
| | gaussians = self.model.encoder( |
| | (batch["context"]["image"]+1)/2, |
| | self.global_step, |
| | )[0] |
| | |
| | |
| | if self.test_cfg.align_pose: |
| | output = self.test_step_align(batch, gaussians) |
| | else: |
| | with self.benchmarker.time("decoder", num_calls=v): |
| | output = self.model.decoder.forward( |
| | gaussians, |
| | batch["target"]["extrinsics"], |
| | batch["target"]["intrinsics"], |
| | batch["target"]["near"], |
| | batch["target"]["far"], |
| | (h, w), |
| | ) |
| | |
| | |
| | if self.test_cfg.compute_scores: |
| | overlap = batch["context"]["overlap"][0] |
| | overlap_tag = get_overlap_tag(overlap) |
| |
|
| | rgb_pred = output.color[0] |
| | rgb_gt = batch["target"]["image"][0] |
| | all_metrics = { |
| | f"lpips_ours": compute_lpips(rgb_gt, rgb_pred).mean(), |
| | f"ssim_ours": compute_ssim(rgb_gt, rgb_pred).mean(), |
| | f"psnr_ours": compute_psnr(rgb_gt, rgb_pred).mean(), |
| | } |
| | methods = ['ours'] |
| |
|
| | self.log_dict(all_metrics) |
| | self.print_preview_metrics(all_metrics, methods, overlap_tag=overlap_tag) |
| | |
| | |
| | (scene,) = batch["scene"] |
| | name = get_cfg()["wandb"]["name"] |
| | path = self.test_cfg.output_path / name |
| | if self.test_cfg.save_image: |
| | for index, color in zip(batch["target"]["index"][0], output.color[0]): |
| | save_image(color, path / scene / f"color/{index:0>6}.png") |
| |
|
| | if self.test_cfg.save_video: |
| | frame_str = "_".join([str(x.item()) for x in batch["context"]["index"][0]]) |
| | save_video( |
| | [a for a in output.color[0]], |
| | path / "video" / f"{scene}_frame_{frame_str}.mp4", |
| | ) |
| |
|
| | if self.test_cfg.save_compare: |
| | |
| | context_img = inverse_normalize(batch["context"]["image"][0]) |
| | comparison = hcat( |
| | add_label(vcat(*context_img), "Context"), |
| | add_label(vcat(*rgb_gt), "Target (Ground Truth)"), |
| | add_label(vcat(*rgb_pred), "Target (Prediction)"), |
| | ) |
| | save_image(comparison, path / f"{scene}.png") |
| | |
| | def test_step_align(self, batch, gaussians): |
| | self.model.encoder.eval() |
| | |
| | for param in self.model.encoder.parameters(): |
| | param.requires_grad = False |
| |
|
| | b, v, _, h, w = batch["target"]["image"].shape |
| | output_c2ws = batch["target"]["extrinsics"] |
| | with torch.set_grad_enabled(True): |
| | cam_rot_delta = nn.Parameter(torch.zeros([b, v, 6], requires_grad=True, device=output_c2ws.device)) |
| | cam_trans_delta = nn.Parameter(torch.zeros([b, v, 3], requires_grad=True, device=output_c2ws.device)) |
| | opt_params = [] |
| | self.register_buffer("identity", torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]).to(output_c2ws)) |
| | opt_params.append( |
| | { |
| | "params": [cam_rot_delta], |
| | "lr": 0.005, |
| | } |
| | ) |
| | opt_params.append( |
| | { |
| | "params": [cam_trans_delta], |
| | "lr": 0.005, |
| | } |
| | ) |
| | pose_optimizer = torch.optim.Adam(opt_params) |
| | extrinsics = output_c2ws.clone() |
| | with self.benchmarker.time("optimize"): |
| | for i in range(self.test_cfg.pose_align_steps): |
| | pose_optimizer.zero_grad() |
| | dx, drot = cam_trans_delta, cam_rot_delta |
| | rot = rotation_6d_to_matrix( |
| | drot + self.identity.expand(b, v, -1) |
| | ) |
| |
|
| | transform = torch.eye(4, device=extrinsics.device).repeat((b, v, 1, 1)) |
| | transform[..., :3, :3] = rot |
| | transform[..., :3, 3] = dx |
| |
|
| | new_extrinsics = torch.matmul(extrinsics, transform) |
| | output = self.model.decoder.forward( |
| | gaussians, |
| | new_extrinsics, |
| | batch["target"]["intrinsics"], |
| | batch["target"]["near"], |
| | batch["target"]["far"], |
| | (h, w), |
| | |
| | |
| | ) |
| |
|
| | |
| | total_loss = 0 |
| | for loss_fn in self.losses: |
| | loss = loss_fn.forward(output, batch, gaussians, self.global_step) |
| | total_loss = total_loss + loss |
| |
|
| | total_loss.backward() |
| | pose_optimizer.step() |
| | |
| | |
| | output = self.model.decoder.forward( |
| | gaussians, |
| | new_extrinsics, |
| | batch["target"]["intrinsics"], |
| | batch["target"]["near"], |
| | batch["target"]["far"], |
| | (h, w), |
| | ) |
| |
|
| | return output |
| |
|
| | def on_test_end(self) -> None: |
| | name = get_cfg()["wandb"]["name"] |
| | self.benchmarker.dump(self.test_cfg.output_path / name / "benchmark.json") |
| | self.benchmarker.dump_memory( |
| | self.test_cfg.output_path / name / "peak_memory.json" |
| | ) |
| | self.benchmarker.summarize() |
| |
|
    @rank_zero_only
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Validate on one scene: render the context views, log metrics and visuals."""
        batch: BatchedExample = self.data_shim(batch)

        if self.global_rank == 0:
            print(
                f"validation step {self.global_step}; "
                f"scene = {batch['scene']}; "
                f"context = {batch['context']['index'].tolist()}"
            )

        # Single-scene validation only.
        b, v, _, h, w = batch["context"]["image"].shape
        assert b == 1
        visualization_dump = {}

        # NOTE(review): training_step feeds the model images mapped to [0, 1],
        # but here the raw [-1, 1] images are passed — confirm this is intended.
        encoder_output, output = self.model(batch["context"]["image"], self.global_step, visualization_dump=visualization_dump)
        gaussians, pred_pose_enc_list, depth_dict = encoder_output.gaussians, encoder_output.pred_pose_enc_list, encoder_output.depth_dict
        pred_context_pose, distill_infos = encoder_output.pred_context_pose, encoder_output.distill_infos
        infos = encoder_output.infos

        # Approximate Gaussian count after voxelization.
        GS_num = infos['voxelize_ratio'] * (h*w*v)
        self.log("val/GS_num", GS_num)

        num_context_views = pred_context_pose['extrinsic'].shape[1]
        num_target_views = batch["target"]["extrinsics"].shape[1]
        rgb_pred = output.color[0].float()
        depth_pred = vis_depth_map(output.depth[0])

        # Per-Gaussian depths from the visualization dump (averaged if 3-channel).
        gaussian_means = visualization_dump["depth"][0].squeeze()
        if gaussian_means.shape[-1] == 3:
            gaussian_means = gaussian_means.mean(dim=-1)

        # Image-quality metrics against the context ground truth in [0, 1].
        rgb_gt = (batch["context"]["image"][0].float() + 1) / 2
        psnr = compute_psnr(rgb_gt, rgb_pred).mean()
        self.log(f"val/psnr", psnr)
        lpips = compute_lpips(rgb_gt, rgb_pred).mean()
        self.log(f"val/lpips", lpips)
        ssim = compute_ssim(rgb_gt, rgb_pred).mean()  # local name shadows the imported `ssim`
        self.log(f"val/ssim", ssim)

        # Consistency between rendered depth and the model-predicted depth
        # (unmasked here, unlike training).
        consis_absrel = abs_relative_difference(
            rearrange(output.depth, "b v h w -> (b v) h w"),
            rearrange(depth_dict['depth'].squeeze(-1), "b v h w -> (b v) h w"),
        )
        self.log("val/consis_absrel", consis_absrel.mean())

        consis_delta1 = delta1_acc(
            rearrange(output.depth, "b v h w -> (b v) h w"),
            rearrange(depth_dict['depth'].squeeze(-1), "b v h w -> (b v) h w"),
            valid_mask=rearrange(torch.ones_like(output.depth, device=output.depth.device, dtype=torch.bool), "b v h w -> (b v) h w"),
        )
        self.log("val/consis_delta1", consis_delta1.mean())

        # Mean absolute depth difference within the confident region.
        diff_map = torch.abs(output.depth - depth_dict['depth'].squeeze(-1))
        self.log("val/consis_mse", diff_map[distill_infos['conf_mask']].mean())

        # Assemble the qualitative comparison sheet.
        context_img = inverse_normalize(batch["context"]["image"][0])

        context = []
        for i in range(context_img.shape[0]):
            context.append(context_img[i])

        colored_diff_map = vis_depth_map(diff_map[0], near=torch.tensor(1e-4, device=diff_map.device), far=torch.tensor(1.0, device=diff_map.device))
        model_depth_pred = depth_dict["depth"].squeeze(-1)[0]
        model_depth_pred = vis_depth_map(model_depth_pred)

        # Normal maps from rendered vs model-predicted depth, mapped to [0, 1].
        render_normal = (get_normal_map(output.depth.flatten(0, 1), batch["context"]["intrinsics"].flatten(0, 1)).permute(0, 3, 1, 2) + 1) / 2.
        pred_normal = (get_normal_map(depth_dict['depth'].flatten(0, 1).squeeze(-1), batch["context"]["intrinsics"].flatten(0, 1)).permute(0, 3, 1, 2) + 1) / 2.

        comparison = hcat(
            add_label(vcat(*context), "Context"),
            add_label(vcat(*rgb_gt), "Target (Ground Truth)"),
            add_label(vcat(*rgb_pred), "Target (Prediction)"),
            add_label(vcat(*depth_pred), "Depth (Prediction)"),
            add_label(vcat(*model_depth_pred), "Depth (VGGT Prediction)"),
            add_label(vcat(*render_normal), "Normal (Prediction)"),
            add_label(vcat(*pred_normal), "Normal (VGGT Prediction)"),
            add_label(vcat(*colored_diff_map), "Diff Map"),
        )

        # Halve the sheet resolution before logging.
        comparison = torch.nn.functional.interpolate(
            comparison.unsqueeze(0),
            scale_factor=0.5,
            mode='bicubic',
            align_corners=False
        ).squeeze(0)

        self.logger.log_image(
            "comparison",
            [prep_image(add_border(comparison))],
            step=self.global_step,
            caption=batch["scene"],
        )

        # Optional encoder-specific visualizations.
        if self.encoder_visualizer is not None:
            for k, image in self.encoder_visualizer.visualize(
                batch["context"], self.global_step
            ).items():
                self.logger.log_image(k, [prep_image(image)], step=self.global_step)

        # Camera-trajectory videos for qualitative inspection.
        self.render_video_interpolation(batch)
        self.render_video_wobble(batch)
        if self.train_cfg.extended_visualization:
            self.render_video_interpolation_exaggerated(batch)
| |
|
| | @rank_zero_only |
| | def render_video_wobble(self, batch: BatchedExample) -> None: |
| | |
| | _, v, _, _ = batch["context"]["extrinsics"].shape |
| | if v != 2: |
| | return |
| |
|
| | def trajectory_fn(t): |
| | origin_a = batch["context"]["extrinsics"][:, 0, :3, 3] |
| | origin_b = batch["context"]["extrinsics"][:, 1, :3, 3] |
| | delta = (origin_a - origin_b).norm(dim=-1) |
| | extrinsics = generate_wobble( |
| | batch["context"]["extrinsics"][:, 0], |
| | delta * 0.25, |
| | t, |
| | ) |
| | intrinsics = repeat( |
| | batch["context"]["intrinsics"][:, 0], |
| | "b i j -> b v i j", |
| | v=t.shape[0], |
| | ) |
| | return extrinsics, intrinsics |
| |
|
| | return self.render_video_generic(batch, trajectory_fn, "wobble", num_frames=60) |
| |
|
| | @rank_zero_only |
| | def render_video_interpolation(self, batch: BatchedExample) -> None: |
| | _, v, _, _ = batch["context"]["extrinsics"].shape |
| |
|
| | def trajectory_fn(t): |
| | extrinsics = interpolate_extrinsics( |
| | batch["context"]["extrinsics"][0, 0], |
| | ( |
| | batch["context"]["extrinsics"][0, 1] |
| | if v == 2 |
| | else batch["target"]["extrinsics"][0, 0] |
| | ), |
| | t, |
| | ) |
| | intrinsics = interpolate_intrinsics( |
| | batch["context"]["intrinsics"][0, 0], |
| | ( |
| | batch["context"]["intrinsics"][0, 1] |
| | if v == 2 |
| | else batch["target"]["intrinsics"][0, 0] |
| | ), |
| | t, |
| | ) |
| | return extrinsics[None], intrinsics[None] |
| |
|
| | return self.render_video_generic(batch, trajectory_fn, "rgb") |
| |
|
    @rank_zero_only
    def render_video_interpolation_exaggerated(self, batch: BatchedExample) -> None:
        """Render an exaggerated interpolation video with added camera wobble."""
        # Two context views are required.
        _, v, _, _ = batch["context"]["extrinsics"].shape
        if v != 2:
            return

        def trajectory_fn(t):
            origin_a = batch["context"]["extrinsics"][:, 0, :3, 3]
            origin_b = batch["context"]["extrinsics"][:, 1, :3, 3]
            delta = (origin_a - origin_b).norm(dim=-1)
            # Wobble with radius half the context baseline, 5 cycles.
            tf = generate_wobble_transformation(
                delta * 0.5,
                t,
                5,
                scale_radius_with_t=False,
            )
            # Overshoot the interpolation range: t*5-2 sweeps past both endpoints.
            extrinsics = interpolate_extrinsics(
                batch["context"]["extrinsics"][0, 0],
                (
                    batch["context"]["extrinsics"][0, 1]
                    if v == 2
                    else batch["target"]["extrinsics"][0, 0]
                ),
                t * 5 - 2,
            )
            intrinsics = interpolate_intrinsics(
                batch["context"]["intrinsics"][0, 0],
                (
                    batch["context"]["intrinsics"][0, 1]
                    if v == 2
                    else batch["target"]["intrinsics"][0, 0]
                ),
                t * 5 - 2,
            )
            # NOTE(review): `extrinsics` has no explicit batch dim while `tf`
            # does; `extrinsics @ tf` relies on broadcasting — confirm shapes
            # line up with `intrinsics[None]`.
            return extrinsics @ tf, intrinsics[None]

        return self.render_video_generic(
            batch,
            trajectory_fn,
            "interpolation_exagerrated",
            num_frames=300,
            smooth=False,
            loop_reverse=False,
        )
| |
|
    @rank_zero_only
    def render_video_generic(
        self,
        batch: BatchedExample,
        trajectory_fn: TrajectoryFn,
        name: str,
        num_frames: int = 30,
        smooth: bool = True,
        loop_reverse: bool = True,
    ) -> None:
        """Render a video along `trajectory_fn` and log it to wandb (or locally).

        Args:
            batch: current batch; its context views are re-encoded here.
            trajectory_fn: maps t in [0, 1] to (extrinsics, intrinsics).
            name: key suffix used for the logged video.
            num_frames: number of trajectory samples.
            smooth: apply cosine easing to t.
            loop_reverse: append reversed frames to make a seamless ping-pong loop.
        """
        # Re-encode context views ([-1, 1] -> [0, 1]) to obtain the Gaussians.
        encoder_output = self.model.encoder((batch["context"]["image"]+1)/2, self.global_step)
        gaussians, pred_pose_enc_list = encoder_output.gaussians, encoder_output.pred_pose_enc_list

        t = torch.linspace(0, 1, num_frames, dtype=torch.float32, device=self.device)
        if smooth:
            # Cosine ease-in/ease-out over [0, 1].
            t = (torch.cos(torch.pi * (t + 1)) + 1) / 2

        extrinsics, intrinsics = trajectory_fn(t)

        _, _, _, h, w = batch["context"]["image"].shape

        # Use the first context view's near/far planes for every frame.
        near = repeat(batch["context"]["near"][:, 0], "b -> b v", v=num_frames)
        far = repeat(batch["context"]["far"][:, 0], "b -> b v", v=num_frames)
        output = self.model.decoder.forward(
            gaussians, extrinsics, intrinsics, near, far, (h, w), "depth"
        )
        # Each frame: RGB stacked above colorized depth.
        images = [
            vcat(rgb, depth)
            for rgb, depth in zip(output.color[0], vis_depth_map(output.depth[0]))
        ]

        video = torch.stack(images)
        video = (video.clip(min=0, max=1) * 255).type(torch.uint8).cpu().numpy()
        if loop_reverse:
            # Ping-pong loop: reversed frames minus the duplicated endpoints.
            video = pack([video, video[::-1][1:-1]], "* c h w")[0]
        visualizations = {
            f"video/{name}": wandb.Video(video[None], fps=30, format="mp4")
        }

        # Prefer wandb; fall back to writing mp4 files through the local logger.
        try:
            wandb.log(visualizations)
        except Exception:
            assert isinstance(self.logger, LocalLogger)
            for key, value in visualizations.items():
                # Uses a private wandb helper to decode the video tensor.
                tensor = value._prepare_video(value.data)
                clip = mpy.ImageSequenceClip(list(tensor), fps=30)
                dir = LOG_PATH / key
                dir.mkdir(exist_ok=True, parents=True)
                clip.write_videofile(
                    str(dir / f"{self.global_step:0>6}.mp4"), logger=None
                )
| |
|
| | def print_preview_metrics(self, metrics: dict[str, float | Tensor], methods: list[str] | None = None, overlap_tag: str | None = None) -> None: |
| | if getattr(self, "running_metrics", None) is None: |
| | self.running_metrics = metrics |
| | self.running_metric_steps = 1 |
| | else: |
| | s = self.running_metric_steps |
| | self.running_metrics = { |
| | k: ((s * v) + metrics[k]) / (s + 1) |
| | for k, v in self.running_metrics.items() |
| | } |
| | self.running_metric_steps += 1 |
| |
|
| | if overlap_tag is not None: |
| | if getattr(self, "running_metrics_sub", None) is None: |
| | self.running_metrics_sub = {overlap_tag: metrics} |
| | self.running_metric_steps_sub = {overlap_tag: 1} |
| | elif overlap_tag not in self.running_metrics_sub: |
| | self.running_metrics_sub[overlap_tag] = metrics |
| | self.running_metric_steps_sub[overlap_tag] = 1 |
| | else: |
| | s = self.running_metric_steps_sub[overlap_tag] |
| | self.running_metrics_sub[overlap_tag] = {k: ((s * v) + metrics[k]) / (s + 1) |
| | for k, v in self.running_metrics_sub[overlap_tag].items()} |
| | self.running_metric_steps_sub[overlap_tag] += 1 |
| |
|
| | metric_list = ["psnr", "lpips", "ssim"] |
| |
|
| | def print_metrics(runing_metric, methods=None): |
| | table = [] |
| | if methods is None: |
| | methods = ['ours'] |
| |
|
| | for method in methods: |
| | row = [ |
| | f"{runing_metric[f'{metric}_{method}']:.3f}" |
| | for metric in metric_list |
| | ] |
| | table.append((method, *row)) |
| |
|
| | headers = ["Method"] + metric_list |
| | table = tabulate(table, headers) |
| | print(table) |
| |
|
| | print("All Pairs:") |
| | print_metrics(self.running_metrics, methods) |
| | if overlap_tag is not None: |
| | for k, v in self.running_metrics_sub.items(): |
| | print(f"Overlap: {k}") |
| | print_metrics(v, methods) |
| |
|
| | def configure_optimizers(self): |
| | new_params, new_param_names = [], [] |
| | pretrained_params, pretrained_param_names = [], [] |
| | for name, param in self.named_parameters(): |
| | if not param.requires_grad: |
| | continue |
| | |
| | if "gaussian_param_head" in name or "interm" in name: |
| | new_params.append(param) |
| | new_param_names.append(name) |
| | else: |
| | pretrained_params.append(param) |
| | pretrained_param_names.append(name) |
| | |
| | param_dicts = [ |
| | { |
| | "params": new_params, |
| | "lr": self.optimizer_cfg.lr, |
| | }, |
| | { |
| | "params": pretrained_params, |
| | "lr": self.optimizer_cfg.lr * self.optimizer_cfg.backbone_lr_multiplier, |
| | }, |
| | ] |
| | optimizer = torch.optim.AdamW(param_dicts, lr=self.optimizer_cfg.lr, weight_decay=0.05, betas=(0.9, 0.95)) |
| | warm_up_steps = self.optimizer_cfg.warm_up_steps |
| | warm_up = torch.optim.lr_scheduler.LinearLR( |
| | optimizer, |
| | 1 / warm_up_steps, |
| | 1, |
| | total_iters=warm_up_steps, |
| | ) |
| | |
| | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=get_cfg()["trainer"]["max_steps"], eta_min=self.optimizer_cfg.lr * 0.1) |
| | lr_scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warm_up, lr_scheduler], milestones=[warm_up_steps]) |
| |
|
| | return { |
| | "optimizer": optimizer, |
| | "lr_scheduler": { |
| | "scheduler": lr_scheduler, |
| | "interval": "step", |
| | "frequency": 1, |
| | }, |
| | } |
| |
|