from dataclasses import dataclass from pathlib import Path from typing import Optional, Protocol, runtime_checkable import moviepy.editor as mpy import torch # import wandb import swanlab as wandb from einops import pack, rearrange, repeat, einsum from jaxtyping import Float from pytorch_lightning import LightningModule # from pytorch_lightning.loggers.wandb import WandbLogger from swanlab.integration.pytorch_lightning import SwanLabLogger from pytorch_lightning.utilities import rank_zero_only from torch import Tensor, nn, optim import numpy as np import json import os import time from tqdm import tqdm import torch.nn.functional as F import math from ..dataset.data_module import get_data_shim from ..dataset.types import BatchedExample from ..dataset import DatasetCfg from ..evaluation.metrics import compute_lpips, compute_psnr, compute_ssim from ..global_cfg import get_cfg from ..loss import Loss from ..misc.benchmarker import Benchmarker from ..misc.image_io import prep_image, save_image, save_video from ..misc.LocalLogger import LOG_PATH, LocalLogger from ..misc.step_tracker import StepTracker from ..visualization.annotation import add_label from ..visualization.camera_trajectory.interpolation import ( interpolate_extrinsics, interpolate_intrinsics, ) from ..visualization.camera_trajectory.wobble import ( generate_wobble, generate_wobble_transformation, ) from ..visualization.color_map import apply_color_map_to_image from ..visualization.layout import add_border, hcat, vcat from ..visualization.validation_in_3d import render_cameras, render_projections from .decoder.decoder import Decoder, DepthRenderingMode, DecoderOutput from .encoder import Encoder from .encoder.visualization.encoder_visualizer import EncoderVisualizer from src.visualization.vis_depth import viz_depth_tensor from PIL import Image from ..misc.stablize_camera import render_stabilization_path from .ply_export import save_gaussian_ply from .encoder.encoder_volsplat import print_mem import MinkowskiEngine as ME ###测试### from ..test.visual import save_output_images from ..test.export_ply import export_raw_points_step # import debugpy # try: # # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1 # debugpy.listen(("localhost", 9326)) # print("Waiting for debugger attach") # debugpy.wait_for_client() # except Exception as e: # pass @dataclass class OptimizerCfg: lr: float warm_up_steps: int lr_monodepth: float weight_decay: float @dataclass class TestCfg: output_path: Path compute_scores: bool save_image: bool save_video: bool eval_time_skip_steps: int save_gt_image: bool save_input_images: bool save_depth: bool save_depth_concat_img: bool save_depth_npy: bool save_gaussian: bool render_chunk_size: int | None stablize_camera: bool stab_camera_kernel: int @dataclass class TrainCfg: depth_mode: DepthRenderingMode | None extended_visualization: bool print_log_every_n_steps: int eval_model_every_n_val: int eval_data_length: int eval_deterministic: bool eval_time_skip_steps: int eval_save_model: bool l1_loss: bool intermediate_loss_weight: float no_viz_video: bool viz_depth: bool forward_depth_only: bool train_ignore_large_loss: float no_log_projections: bool @runtime_checkable class TrajectoryFn(Protocol): def __call__( self, t: Float[Tensor, " t"], ) -> tuple[ Float[Tensor, "batch view 4 4"], # extrinsics Float[Tensor, "batch view 3 3"], # intrinsics ]: pass class ModelWrapper(LightningModule): logger: Optional[SwanLabLogger] encoder: nn.Module encoder_visualizer: Optional[EncoderVisualizer] decoder: Decoder losses: nn.ModuleList optimizer_cfg: OptimizerCfg test_cfg: TestCfg train_cfg: TrainCfg step_tracker: StepTracker | None eval_data_cfg: Optional[DatasetCfg | None] def __init__( self, optimizer_cfg: OptimizerCfg, test_cfg: TestCfg, train_cfg: TrainCfg, encoder: Encoder, encoder_visualizer: Optional[EncoderVisualizer], decoder: Decoder, losses: list[Loss], step_tracker: StepTracker | None, eval_data_cfg: Optional[DatasetCfg | None] = None, ) -> None: super().__init__() self.optimizer_cfg = optimizer_cfg self.test_cfg = test_cfg self.train_cfg = train_cfg self.step_tracker = step_tracker self.eval_data_cfg = eval_data_cfg # Set up the model. self.encoder = encoder self.encoder_visualizer = encoder_visualizer self.decoder = decoder self.data_shim = get_data_shim(self.encoder) self.losses = nn.ModuleList(losses) # This is used for testing. self.benchmarker = Benchmarker() self.eval_cnt = 0 if self.test_cfg.compute_scores: self.test_step_outputs = {} self.time_skip_steps_dict = {"encoder": 0, "decoder": 0} # This is used for testing. self.benchmarker = Benchmarker() self.eval_cnt = 0 if self.test_cfg.compute_scores: self.test_step_outputs = {} self.time_skip_steps_dict = {"encoder": 0, "decoder": 0} # MOD: debug开关(默认False)。要启用参数更新检查,在训练脚本里设置 model._check_param_updates = True self._check_param_updates = False # MOD: 保存训练 step 开始时的参数快照(仅在_check_param_updates=True时填充) self._before_params_snapshot = None def training_step(self, batch, batch_idx): batch: BatchedExample = self.data_shim(batch) _, _, _, h, w = batch["target"]["image"].shape _, views, _, _, _ = batch["context"]["image"].shape # MOD: 如果开启了参数更新检查,保存当前参数快照(用于在 optimizer.step 后对比) if getattr(self, "_check_param_updates", False): try: self._before_params_snapshot = self.snapshot_params() except Exception as e: print("[DEBUG] failed to snapshot params:", e) #######打印场景ID######## print(f"Training step{self.global_step},Number of images:{views}:scene IDs:{batch['scene']}") # if self.global_step < 5000: # ues_voxelnet = True # else: # ues_voxelnet = True print_mem("before encoder") # Run the model. #ues_voxelnet=ues_voxelnet gaussians = self.encoder( batch["context"], self.global_step, False, scene_names=batch["scene"] ) print_mem("after encoder") if isinstance(gaussians, dict) and len(gaussians) == 2: supervise_intermediate_depth = False pred_depths = gaussians["depths"] gaussians = gaussians["gaussians"] if isinstance(gaussians, dict) and len(gaussians) == 3: supervise_intermediate_depth = True pred_depths = gaussians["depths"] intermediate_gaussians = gaussians["intermediate_gaussians"] gaussians = gaussians["gaussians"] # ------------------ < 新增:检查 pred_depths 是否全为 0,并中断训练 > ----------------- try: with torch.no_grad(): pd = pred_depths.detach() # 期望 pred_depths shape = [B, V, H, W] B = pd.shape[0] # 若存在 NaN 或 非数值项,先转换为 0(避免 sum 产生 NaN) pd = torch.nan_to_num(pd, nan=0.0, posinf=0.0, neginf=0.0) per_sample_sum = pd.abs().view(B, -1).sum(dim=1) zero_mask = (per_sample_sum == 0) if zero_mask.any(): zero_idx = torch.nonzero(zero_mask, as_tuple=False).squeeze(1) # 处理 single-element 情况,统一为 list[int] if zero_idx.numel() == 1: zero_idx = [int(zero_idx.item())] else: zero_idx = [int(i.item()) for i in zero_idx] # 提取对应的 scene id(兼容 list/tuple/tensor/string 等多种形式) scene_ids = [] for i in zero_idx: try: s = batch["scene"][i] except Exception: # 如果索引失败,跳过 s = None # 将各种类型转换为可打印的 python 值 try: if isinstance(s, torch.Tensor): # 若是单值 tensor if s.numel() == 1: scene_ids.append(s.item()) else: # 向量 tensor,转为 list scene_ids.append(s.cpu().tolist()) else: scene_ids.append(s) except Exception: scene_ids.append(str(s)) # 多卡/分布式训练中只由 rank0 打印,避免重复日志 is_rank0 = True try: is_rank0 = getattr(self.trainer, "global_rank", getattr(self, "global_rank", 0)) == 0 except Exception: pass if is_rank0: print(f"[STOPPING] pred_depths all-zero detected for batch indices {zero_idx}; scene IDs: {scene_ids}") # 方式 A:设置 trainer 的停止标志(较为优雅) try: self.trainer.should_stop = True except Exception: pass # 方式 B:立即抛出异常以立刻中断训练(根据你的需求,选择保留或注释) raise RuntimeError(f"Stopping training because pred_depths are all zero for scenes: {scene_ids}") except Exception as e: # 如果检测逻辑本身出错,则打印 debug 信息并继续训练(以避免无意中挂起训练) print("[DEBUG] pred_depths zero-check failed or triggered stop. info:", e) # 如果你希望检测失败时也中断训练,可以在这里改为 raise # ------------------ < 新增结束 > ------------------ print_mem("before decoder") #############设置中间深度监督################### if gaussians.means.size(0) != batch["target"]["extrinsics"].size(0): supervise_intermediate_depth = True assert gaussians.means.size(0) % batch["target"]["extrinsics"].size(0) == 0 num_depths = gaussians.means.size(0) // batch["target"]["extrinsics"].size( 0 ) # add loss to intermediate depth predictions target_extrinsics = torch.cat( [batch["target"]["extrinsics"]] * num_depths, dim=0 ) target_intrinsics = torch.cat( [batch["target"]["intrinsics"]] * num_depths, dim=0 ) target_near = torch.cat([batch["target"]["near"]] * num_depths, dim=0) target_far = torch.cat([batch["target"]["far"]] * num_depths, dim=0) #[2B, V, 3, 256, 448] output_all = self.decoder.forward( gaussians, target_extrinsics, target_intrinsics, target_near, target_far, (h, w), depth_mode=self.train_cfg.depth_mode, ) # split batch_size = batch["target"]["extrinsics"].size(0) # order: intermediate depth, final depth output_intermediate = DecoderOutput( color=output_all.color[:-batch_size], #[B, V, 3, H, W] depth=( output_all.depth[:-batch_size] if output_all.depth is not None else None ), ) output = DecoderOutput( color=output_all.color[-batch_size:], #[B, V, 3, H, W] depth=( output_all.depth[-batch_size:] if output_all.depth is not None else None ), ) else: output = self.decoder.forward( gaussians, batch["target"]["extrinsics"], batch["target"]["intrinsics"], batch["target"]["near"], batch["target"]["far"], (h, w), depth_mode=self.train_cfg.depth_mode, ) ############################ if supervise_intermediate_depth: output_intermediate = self.decoder.forward( intermediate_gaussians, batch["target"]["extrinsics"], batch["target"]["intrinsics"], batch["target"]["near"], batch["target"]["far"], (h, w), depth_mode=self.train_cfg.depth_mode, ) print_mem("after decoder") target_gt = batch["target"]["image"] # Compute metrics. psnr_probabilistic = compute_psnr( rearrange(target_gt, "b v c h w -> (b v) c h w"), rearrange(output.color, "b v c h w -> (b v) c h w"), ) #把图片打印出来看看效果 # save_output_images(output.color, save_dir="/mnt/pfs/users/wangweijie/yeqing/BEV-Splat/outputs/out_image", prefix="prob_output") # save_output_images(target_gt, "/mnt/pfs/users/wangweijie/yeqing/BEV-Splat/outputs/rgb_image") self.log("train/psnr", psnr_probabilistic.mean()) # Compute and log loss. total_loss = 0 valid_depth_mask = None for loss_fn in self.losses: if loss_fn.name == "mse": loss = loss_fn.forward( output, batch, gaussians, self.global_step, l1_loss=self.train_cfg.l1_loss, clamp_large_error=self.train_cfg.train_ignore_large_loss, valid_depth_mask=valid_depth_mask, ) else: loss = loss_fn.forward( output, batch, gaussians, self.global_step, valid_depth_mask=valid_depth_mask, ) self.log(f"loss/{loss_fn.name}", loss) total_loss = total_loss + loss # color loss on intermediate output if supervise_intermediate_depth: for loss_fn in self.losses: batch_size = batch["target"]["extrinsics"].size(0) if output_intermediate.color.size(0) != batch_size: assert output_intermediate.color.size(0) % batch_size == 0 num_intermediate = output_intermediate.color.size(0) // batch_size intermediate_loss = 0 for i in range(num_intermediate): curr_output = DecoderOutput( color=output_intermediate.color[ (batch_size * i) : (batch_size * (i + 1)) ], depth=( output_intermediate.depth[ (batch_size * i) : (batch_size * (i + 1)) ] if output_intermediate.depth is not None else None ), ) curr_loss_weight = self.train_cfg.intermediate_loss_weight ** ( num_intermediate - i ) if loss_fn.name == "mse": loss = loss_fn.forward( curr_output, batch, gaussians, self.global_step, l1_loss=self.train_cfg.l1_loss, clamp_large_error=self.train_cfg.train_ignore_large_loss, valid_depth_mask=valid_depth_mask, ) else: loss = loss_fn.forward( curr_output, batch, gaussians, self.global_step, valid_depth_mask=valid_depth_mask, ) intermediate_loss = intermediate_loss + curr_loss_weight * loss self.log(f"loss/{loss_fn.name}_intermediate", intermediate_loss) total_loss = total_loss + intermediate_loss else: if loss_fn.name == "mse": loss = loss_fn.forward( output_intermediate, batch, gaussians, self.global_step, l1_loss=self.train_cfg.l1_loss, clamp_large_error=self.train_cfg.train_ignore_large_loss, valid_depth_mask=valid_depth_mask, ) else: loss = loss_fn.forward( output_intermediate, batch, gaussians, self.global_step, valid_depth_mask=valid_depth_mask, ) self.log(f"loss/{loss_fn.name}_intermediate", loss) total_loss = ( total_loss + self.train_cfg.intermediate_loss_weight * loss ) self.log("loss/total", total_loss) if ( self.global_rank == 0 and self.global_step % self.train_cfg.print_log_every_n_steps == 0 ): print( f"train step {self.global_step}; " f"scene = {[x[:20] for x in batch['scene']]}; " f"context = {batch['context']['index'].tolist()}; " f"bound = [{batch['context']['near'].detach().cpu().numpy().mean()} " f"{batch['context']['far'].detach().cpu().numpy().mean()}]; " f"loss = {total_loss:.6f}" ) self.log("info/near", batch["context"]["near"].detach().cpu().numpy().mean()) self.log("info/far", batch["context"]["far"].detach().cpu().numpy().mean()) self.log("info/global_step", self.global_step) # hack for ckpt monitor # Tell the data loader processes about the current step. if self.step_tracker is not None: self.step_tracker.set_step(self.global_step) if self.global_step == 5 and self.global_rank == 0: os.system("nvidia-smi") return total_loss def test_step(self, batch, batch_idx): batch: BatchedExample = self.data_shim(batch) b, v, _, h, w = batch["target"]["image"].shape assert b == 1 pred_depths = None # save input views for visualization if self.test_cfg.save_input_images: (scene,) = batch["scene"] self.test_cfg.output_path = os.path.join(get_cfg()["output_dir"], "metrics") path = Path(get_cfg()["output_dir"]) input_images = batch["context"]["image"][0] # [V, 3, H, W] index = batch["context"]["index"][0] for idx, color in zip(index, input_images): save_image(color, path / "images" / scene / f"color/input_{idx:0>6}.png") # save depth vis if self.test_cfg.save_depth or self.test_cfg.save_gaussian: visualization_dump = {} else: visualization_dump = None # Render Gaussians. with self.benchmarker.time("encoder"): gaussians = self.encoder( batch["context"], self.global_step, deterministic=False, visualization_dump=visualization_dump, ) if isinstance(gaussians, dict): pred_depths = gaussians["depths"] if "depth" in batch["context"]: depth_gt = batch["context"]["depth"] gaussians = gaussians["gaussians"] # save gaussians if self.test_cfg.save_gaussian: scene = batch["scene"][0] save_path = Path(get_cfg()['output_dir']) / 'gaussians' / (scene + '.ply') save_gaussian_ply(gaussians, visualization_dump, batch, save_path) if not self.train_cfg.forward_depth_only: with self.benchmarker.time("decoder", num_calls=v): camera_poses = batch["target"]["extrinsics"] if self.test_cfg.stablize_camera: stable_poses = render_stabilization_path( camera_poses[0].detach().cpu().numpy(), k_size=self.test_cfg.stab_camera_kernel, ) stable_poses = list( map( lambda x: np.concatenate( (x, np.array([[0.0, 0.0, 0.0, 1.0]])), axis=0 ), stable_poses, ) ) stable_poses = torch.from_numpy(np.stack(stable_poses, axis=0)).to( camera_poses ) camera_poses = stable_poses.unsqueeze(0) if self.test_cfg.render_chunk_size is not None: chunk_size = self.test_cfg.render_chunk_size num_chunks = math.ceil(camera_poses.shape[1] / chunk_size) output = None for i in range(num_chunks): start = chunk_size * i end = chunk_size * (i + 1) render_intrinsics = batch["target"]["intrinsics"] render_near = batch["target"]["near"] render_far = batch["target"]["far"] curr_output = self.decoder.forward( gaussians, camera_poses[:, start:end], render_intrinsics[:, start:end], render_near[:, start:end], render_far[:, start:end], (h, w), depth_mode=None, ) if i == 0: output = curr_output else: # ignore depth output.color = torch.cat( (output.color, curr_output.color), dim=1 ) else: output = self.decoder.forward( gaussians, camera_poses, batch["target"]["intrinsics"], batch["target"]["near"], batch["target"]["far"], (h, w), depth_mode=None, ) (scene,) = batch["scene"] self.test_cfg.output_path = os.path.join(get_cfg()["output_dir"], "metrics") path = Path(get_cfg()["output_dir"]) # save depth if self.test_cfg.save_depth: if self.train_cfg.forward_depth_only: depth = pred_depths[0].cpu().detach() # [V, H, W] else: depth = ( visualization_dump["depth"][0, :, :, :, 0, 0].cpu().detach() ) # [V, H, W] index = batch["context"]["index"][0] if self.test_cfg.save_depth_concat_img: # concat (img0, img1, depth0, depth1) image = batch['context']['image'][0] # [V, 3, H, W] in [0,1] image = rearrange(image, "b c h w -> h (b w) c") # [H, VW, 3] image_concat = (image.detach().cpu().numpy() * 255).astype(np.uint8) # [H, VW, 3] depth_concat = [] for idx, depth_i in zip(index, depth): depth_viz = viz_depth_tensor( 1.0 / depth_i, return_numpy=True ) # [H, W, 3] if self.test_cfg.save_depth_concat_img: depth_concat.append(depth_viz) save_path = path / "images" / scene / "depth" / f"{idx:0>6}.png" save_dir = os.path.dirname(save_path) os.makedirs(save_dir, exist_ok=True) Image.fromarray(depth_viz).save(save_path) # save depth as npy if self.test_cfg.save_depth_npy: depth_npy = depth_i.detach().cpu().numpy() save_path = path / "images" / scene / "depth" / f"{idx:0>6}.npy" save_dir = os.path.dirname(save_path) os.makedirs(save_dir, exist_ok=True) np.save(save_path, depth_npy) if self.test_cfg.save_depth_concat_img: depth_concat = np.concatenate(depth_concat, axis=1) # [H, VW, 3] concat = np.concatenate((image_concat, depth_concat), axis=0) # [2H, VW, 3] save_path = path / "images" / scene / "depth" / f"img_depth_{scene}.png" save_dir = os.path.dirname(save_path) os.makedirs(save_dir, exist_ok=True) Image.fromarray(concat).save(save_path) if self.train_cfg.forward_depth_only: return images_prob = output.color[0] rgb_gt = batch["target"]["image"][0] # Save images. if self.test_cfg.save_image: if self.test_cfg.save_gt_image: for index, color, gt in zip( batch["target"]["index"][0], images_prob, rgb_gt ): save_image(color, path / "images" / scene / f"color/{index:0>6}.png") save_image(gt, path / "images" / scene / f"color/{index:0>6}_gt.png") else: for index, color in zip(batch["target"]["index"][0], images_prob): save_image(color, path / "images" / scene / f"color/{index:0>6}.png") # save video if self.test_cfg.save_video: frame_str = "_".join([str(x.item()) for x in batch["context"]["index"][0]]) save_video( [a for a in images_prob], path / "videos" / f"{scene}_frame_{frame_str}.mp4", ) # compute scores if self.test_cfg.compute_scores: if batch_idx < self.test_cfg.eval_time_skip_steps: self.time_skip_steps_dict["encoder"] += 1 self.time_skip_steps_dict["decoder"] += v if not self.train_cfg.forward_depth_only: rgb = images_prob if f"psnr" not in self.test_step_outputs: self.test_step_outputs[f"psnr"] = [] if f"ssim" not in self.test_step_outputs: self.test_step_outputs[f"ssim"] = [] if f"lpips" not in self.test_step_outputs: self.test_step_outputs[f"lpips"] = [] self.test_step_outputs[f"psnr"].append( compute_psnr(rgb_gt, rgb).mean().item() ) self.test_step_outputs[f"ssim"].append( compute_ssim(rgb_gt, rgb).mean().item() ) self.test_step_outputs[f"lpips"].append( compute_lpips(rgb_gt, rgb).mean().item() ) def on_test_end(self) -> None: out_dir = Path(self.test_cfg.output_path) saved_scores = {} if self.test_cfg.compute_scores: self.benchmarker.dump_memory(out_dir / "peak_memory.json") self.benchmarker.dump(out_dir / "benchmark.json") for metric_name, metric_scores in self.test_step_outputs.items(): avg_scores = sum(metric_scores) / len(metric_scores) saved_scores[metric_name] = avg_scores print(metric_name, avg_scores) with (out_dir / f"scores_{metric_name}_all.json").open("w") as f: json.dump(metric_scores, f) metric_scores.clear() for tag, times in self.benchmarker.execution_times.items(): times = times[int(self.time_skip_steps_dict[tag]) :] saved_scores[tag] = [len(times), np.mean(times)] print( f"{tag}: {len(times)} calls, avg. {np.mean(times)} seconds per call" ) self.time_skip_steps_dict[tag] = 0 with (out_dir / f"scores_all_avg.json").open("w") as f: json.dump(saved_scores, f) self.benchmarker.clear_history() else: self.benchmarker.dump(out_dir / "benchmark.json") self.benchmarker.dump_memory(out_dir / "peak_memory.json") self.benchmarker.summarize() @rank_zero_only def validation_step(self, batch, batch_idx): batch: BatchedExample = self.data_shim(batch) if self.global_rank == 0: print( f"validation step {self.global_step}; " f"scene = {[a[:20] for a in batch['scene']]}; " f"context = {batch['context']['index'].tolist()}" ) # Render Gaussians. b, _, _, h, w = batch["target"]["image"].shape assert b == 1 gaussians_softmax = self.encoder( batch["context"], self.global_step, deterministic=False, ) pred_depths = None if isinstance(gaussians_softmax, dict): pred_depths = gaussians_softmax["depths"] if "depth" in batch["context"]: depth_gt = batch["context"]["depth"] # [B, V, H, W] gaussians_softmax = gaussians_softmax["gaussians"] if not self.train_cfg.forward_depth_only: output_softmax = self.decoder.forward( gaussians_softmax, batch["target"]["extrinsics"], batch["target"]["intrinsics"], batch["target"]["near"], batch["target"]["far"], (h, w), ) rgb_softmax = output_softmax.color[0] # Compute validation metrics. rgb_gt = batch["target"]["image"][0] for tag, rgb in zip(("val",), (rgb_softmax,)): psnr = compute_psnr(rgb_gt, rgb).mean() self.log(f"val/psnr_{tag}", psnr) lpips = compute_lpips(rgb_gt, rgb).mean() self.log(f"val/lpips_{tag}", lpips) ssim = compute_ssim(rgb_gt, rgb).mean() self.log(f"val/ssim_{tag}", ssim) # viz depth if pred_depths is not None: # only visualize predicted depth pred_depths = pred_depths[0] # [V, H, W] # gaussian downsample if pred_depths.shape[1:] != batch["context"]["image"].shape[-2:]: pred_depths = F.interpolate( pred_depths.unsqueeze(1), size=batch["context"]["image"].shape[-2:], mode="bilinear", align_corners=True, ).squeeze(1) inverse_depth_pred = 1.0 / pred_depths concat = [] for i in range(inverse_depth_pred.size(0)): concat.append(inverse_depth_pred[i]) concat = torch.cat(concat, dim=1) # [H, W*N] depth_viz = viz_depth_tensor(concat.cpu().detach()) # [3, H, W*N] # also concat images input_images = batch["context"]["image"][0] # [N, 3, H, W] concat_img = [img for img in input_images] concat_img = torch.cat(concat_img, dim=-1) * 255 # [3, H, W*N] concat = torch.cat( (concat_img.cpu().detach(), depth_viz), dim=1 ) # [3, H*2, W*N] self.logger.log_image( "depth", [concat], step=self.global_step, caption=batch["scene"], ) if not self.train_cfg.forward_depth_only: # Construct comparison image. comparison = hcat( add_label(vcat(*batch["context"]["image"][0]), "Context"), add_label(vcat(*rgb_gt), "Target (Ground Truth)"), add_label(vcat(*rgb_softmax), "Target (Prediction)"), # add_label(vcat(batch["context"]["image"][0]), "Context"), # add_label(vcat(rgb_gt), "Target (Ground Truth)"), # add_label(vcat(rgb_softmax), "Target (Prediction)"), ) self.logger.log_image( "comparison", [prep_image(add_border(comparison))], step=self.global_step, caption=batch["scene"], ) if not self.train_cfg.no_log_projections: # Render projections and construct projection image. projections = hcat( *render_projections( gaussians_softmax, 256, extra_label="(Prediction)", )[0] ) self.logger.log_image( "projection", [prep_image(add_border(projections))], step=self.global_step, ) # Draw cameras. cameras = hcat(*render_cameras(batch, 256)) self.logger.log_image( "cameras", [prep_image(add_border(cameras))], step=self.global_step ) if self.encoder_visualizer is not None: for k, image in self.encoder_visualizer.visualize( batch["context"], self.global_step ).items(): self.logger.log_image(k, [prep_image(image)], step=self.global_step) # Run video validation step. if not self.train_cfg.no_viz_video: self.render_video_interpolation(batch) self.render_video_wobble(batch) if self.train_cfg.extended_visualization: self.render_video_interpolation_exaggerated(batch) def on_validation_epoch_end(self) -> None: """hack to run the full validation""" if self.trainer.sanity_checking and self.global_rank == 0: print(self.encoder) # log the model to wandb log files if (not self.trainer.sanity_checking) and (self.eval_data_cfg is not None): self.eval_cnt = self.eval_cnt + 1 if self.eval_cnt % self.train_cfg.eval_model_every_n_val == 0: # backup current ckpt before running full test sets eval if self.train_cfg.eval_save_model: ckpt_saved_path = ( self.trainer.checkpoint_callback.format_checkpoint_name( dict( epoch=self.trainer.current_epoch, step=self.trainer.global_step, ) ) ) backup_dir = str( Path(ckpt_saved_path).parent.parent / "checkpoints_backups" ) if self.global_rank == 0: os.makedirs(backup_dir, exist_ok=True) ckpt_saved_path = os.path.join( backup_dir, os.path.basename(ckpt_saved_path) ) # call save_checkpoint on ALL process as suggested by pytorch_lightning self.trainer.save_checkpoint( ckpt_saved_path, weights_only=True, ) if self.global_rank == 0: print(f"backup model to {ckpt_saved_path}.") # run full test sets eval on rank=0 device self.run_full_test_sets_eval() @rank_zero_only def run_full_test_sets_eval(self) -> None: start_t = time.time() pred_depths = None depth_gt = None full_testsets = self.trainer.datamodule.test_dataloader( # dataset_cfg=self.eval_data_cfg ) scores_dict = {} if not self.train_cfg.forward_depth_only: for score_tag in ("psnr", "ssim", "lpips"): scores_dict[score_tag] = {} for method_tag in ("deterministic", "probabilistic"): scores_dict[score_tag][method_tag] = [] # evaluate depth if self.train_cfg.viz_depth: for score_tag in ("abs_rel", "rmse", "a1"): scores_dict[score_tag] = {} for method_tag in ("deterministic", "probabilistic"): scores_dict[score_tag][method_tag] = [] self.benchmarker.clear_history() time_skip_first_n_steps = min( self.train_cfg.eval_time_skip_steps, len(full_testsets) ) time_skip_steps_dict = {"encoder": 0, "decoder": 0} for batch_idx, batch in tqdm( enumerate(full_testsets), total=min(len(full_testsets), self.train_cfg.eval_data_length), ): if batch_idx >= self.train_cfg.eval_data_length: break batch = self.data_shim(batch) batch = self.transfer_batch_to_device(batch, "cuda", dataloader_idx=0) # Render Gaussians. b, v, _, h, w = batch["target"]["image"].shape assert b == 1 if batch_idx < time_skip_first_n_steps: time_skip_steps_dict["encoder"] += 1 time_skip_steps_dict["decoder"] += v with self.benchmarker.time("encoder"): gaussians_probabilistic = self.encoder( batch["context"], self.global_step, deterministic=False, ) if isinstance(gaussians_probabilistic, dict): pred_depths = gaussians_probabilistic["depths"] if "depth" in batch["context"]: depth_gt = batch["context"]["depth"] gaussians_probabilistic = gaussians_probabilistic["gaussians"] if not self.train_cfg.forward_depth_only: with self.benchmarker.time("decoder", num_calls=v): output_probabilistic = self.decoder.forward( gaussians_probabilistic, batch["target"]["extrinsics"], batch["target"]["intrinsics"], batch["target"]["near"], batch["target"]["far"], (h, w), ) rgbs = [output_probabilistic.color[0]] tags = ["probabilistic"] if self.train_cfg.eval_deterministic: gaussians_deterministic = self.encoder( batch["context"], self.global_step, deterministic=True, ) output_deterministic = self.decoder.forward( gaussians_deterministic, batch["target"]["extrinsics"], batch["target"]["intrinsics"], batch["target"]["near"], batch["target"]["far"], (h, w), ) rgbs.append(output_deterministic.color[0]) tags.append("deterministic") # Compute validation metrics. rgb_gt = batch["target"]["image"][0] for tag, rgb in zip(tags, rgbs): scores_dict["psnr"][tag].append( compute_psnr(rgb_gt, rgb).mean().item() ) scores_dict["lpips"][tag].append( compute_lpips(rgb_gt, rgb).mean().item() ) scores_dict["ssim"][tag].append( compute_ssim(rgb_gt, rgb).mean().item() ) # summarise scores and log to logger for score_tag, methods in scores_dict.items(): for method_tag, cur_scores in methods.items(): if len(cur_scores) > 0: cur_mean = sum(cur_scores) / len(cur_scores) self.log(f"test/{score_tag}", cur_mean) # summarise run time for tag, times in self.benchmarker.execution_times.items(): times = times[int(time_skip_steps_dict[tag]) :] print(f"{tag}: {len(times)} calls, avg. {np.mean(times)} seconds per call") self.log(f"test/runtime_avg_{tag}", np.mean(times)) self.benchmarker.clear_history() overall_eval_time = time.time() - start_t print(f"Eval total time cost: {overall_eval_time:.3f}s") self.log("test/runtime_all", overall_eval_time) @rank_zero_only def render_video_wobble(self, batch: BatchedExample) -> None: # Two views are needed to get the wobble radius. _, v, _, _ = batch["context"]["extrinsics"].shape if v != 2: return def trajectory_fn(t): origin_a = batch["context"]["extrinsics"][:, 0, :3, 3] origin_b = batch["context"]["extrinsics"][:, 1, :3, 3] delta = (origin_a - origin_b).norm(dim=-1) extrinsics = generate_wobble( batch["context"]["extrinsics"][:, 0], delta * 0.25, t, ) intrinsics = repeat( batch["context"]["intrinsics"][:, 0], "b i j -> b v i j", v=t.shape[0], ) return extrinsics, intrinsics return self.render_video_generic(batch, trajectory_fn, "wobble", num_frames=60) @rank_zero_only def render_video_interpolation(self, batch: BatchedExample) -> None: _, v, _, _ = batch["context"]["extrinsics"].shape def trajectory_fn(t): extrinsics = interpolate_extrinsics( batch["context"]["extrinsics"][0, 0], ( batch["context"]["extrinsics"][0, 1] if v == 2 else batch["target"]["extrinsics"][0, 0] ), t, ) intrinsics = interpolate_intrinsics( batch["context"]["intrinsics"][0, 0], ( batch["context"]["intrinsics"][0, 1] if v == 2 else batch["target"]["intrinsics"][0, 0] ), t, ) return extrinsics[None], intrinsics[None] return self.render_video_generic(batch, trajectory_fn, "rgb") @rank_zero_only def render_video_interpolation_exaggerated(self, batch: BatchedExample) -> None: # Two views are needed to get the wobble radius. _, v, _, _ = batch["context"]["extrinsics"].shape if v != 2: return def trajectory_fn(t): origin_a = batch["context"]["extrinsics"][:, 0, :3, 3] origin_b = batch["context"]["extrinsics"][:, 1, :3, 3] delta = (origin_a - origin_b).norm(dim=-1) tf = generate_wobble_transformation( delta * 0.5, t, 5, scale_radius_with_t=False, ) extrinsics = interpolate_extrinsics( batch["context"]["extrinsics"][0, 0], ( batch["context"]["extrinsics"][0, 1] if v == 2 else batch["target"]["extrinsics"][0, 0] ), t * 5 - 2, ) intrinsics = interpolate_intrinsics( batch["context"]["intrinsics"][0, 0], ( batch["context"]["intrinsics"][0, 1] if v == 2 else batch["target"]["intrinsics"][0, 0] ), t * 5 - 2, ) return extrinsics @ tf, intrinsics[None] return self.render_video_generic( batch, trajectory_fn, "interpolation_exagerrated", num_frames=300, smooth=False, loop_reverse=False, ) @rank_zero_only def render_video_generic( self, batch: BatchedExample, trajectory_fn: TrajectoryFn, name: str, num_frames: int = 30, smooth: bool = True, loop_reverse: bool = True, ) -> None: # Render probabilistic estimate of scene. gaussians_prob = self.encoder(batch["context"], self.global_step, False) # gaussians_det = self.encoder(batch["context"], self.global_step, True) if isinstance(gaussians_prob, dict): gaussians_prob = gaussians_prob["gaussians"] t = torch.linspace(0, 1, num_frames, dtype=torch.float32, device=self.device) if smooth: t = (torch.cos(torch.pi * (t + 1)) + 1) / 2 extrinsics, intrinsics = trajectory_fn(t) _, _, _, h, w = batch["context"]["image"].shape # Color-map the result. def depth_map(result): near = result[result > 0][:16_000_000].quantile(0.01).log() far = result.view(-1)[:16_000_000].quantile(0.99).log() result = result.log() result = 1 - (result - near) / (far - near) return apply_color_map_to_image(result, "turbo") near = repeat(batch["context"]["near"][:, 0], "b -> b v", v=num_frames) far = repeat(batch["context"]["far"][:, 0], "b -> b v", v=num_frames) output_prob = self.decoder.forward( gaussians_prob, extrinsics, intrinsics, near, far, (h, w), "depth" ) images_prob = [ vcat(rgb, depth) for rgb, depth in zip(output_prob.color[0], depth_map(output_prob.depth[0])) ] images = [ add_border( hcat( add_label(image_prob, "Prediction"), ) ) for image_prob, _ in zip(images_prob, images_prob) ] video = torch.stack(images) video = (video.clip(min=0, max=1) * 255).type(torch.uint8).cpu().numpy() if loop_reverse: video = pack([video, video[::-1][1:-1]], "* c h w")[0] #swanlab不支持 # visualizations = { # f"video/{name}": wandb.Video(video[None], fps=30, format="mp4") # } # Since the PyTorch Lightning doesn't support video logging, log to wandb directly. # try: # wandb.log(visualizations) # except Exception: # assert isinstance(self.logger, LocalLogger) # for key, value in visualizations.items(): # tensor = value._prepare_video(value.data) # clip = mpy.ImageSequenceClip(list(tensor), fps=value._fps) # dir = LOG_PATH / key # dir.mkdir(exist_ok=True, parents=True) # clip.write_videofile( # str(dir / f"{self.global_step:0>6}.mp4"), logger=None # ) def configure_optimizers(self): if self.optimizer_cfg.lr_monodepth > 0: pretrained_params = [] new_params = [] for name, param in self.named_parameters(): if "pretrained" in name: pretrained_params.append(param) else: new_params.append(param) optimizer = torch.optim.AdamW( [ { "params": pretrained_params, "lr": self.optimizer_cfg.lr_monodepth, }, {"params": new_params, "lr": self.optimizer_cfg.lr}, ], weight_decay=self.optimizer_cfg.weight_decay, ) scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer, [self.optimizer_cfg.lr_monodepth, self.optimizer_cfg.lr], self.trainer.max_steps + 10, pct_start=0.01, cycle_momentum=False, anneal_strategy="cos", ) else: optimizer = optim.AdamW( self.parameters(), lr=self.optimizer_cfg.lr, weight_decay=self.optimizer_cfg.weight_decay, ) scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer, self.optimizer_cfg.lr, self.trainer.max_steps + 10, pct_start=0.01, cycle_momentum=False, anneal_strategy="cos", ) return { "optimizer": optimizer, "lr_scheduler": { "scheduler": scheduler, "interval": "step", "frequency": 1, }, } # ----------------- MOD: debug helper functions ----------------- def snapshot_params(self): """Return a snapshot dict name->cpu clone for all trainable params.""" return {n: p.detach().cpu().clone() for n, p in self.named_parameters() if p.requires_grad} def compare_params_snapshot(self, before_snap, tol=0.0, show_n=50): """Compare saved snapshot with current params; report params with zero update.""" if before_snap is None: print("[CHECK] no before-snapshot provided") return zero_updates = [] small_updates = [] for name, p in self.named_parameters(): if not p.requires_grad: continue if name not in before_snap: continue before = before_snap[name] cur = p.detach().cpu() diff = (cur - before).abs().sum().item() if diff <= tol: zero_updates.append((name, diff)) else: if diff < 1e-12: small_updates.append((name, diff)) print(f"[CHECK] zero-updates={len(zero_updates)}, very-small-updates={len(small_updates)}") if zero_updates: print(" params with (nearly) zero update (first {}):".format(show_n)) for n, d in zero_updates[:show_n]: print(f" {n} diff_sum={d:.3e}") def list_params_with_no_grad_after_backward(self, tiny_thresh=1e-12, show_n=50): """Call after backward: list params with grad None or extremely small norm.""" none_grad = [] very_small = [] nan_inf = [] for name, p in self.named_parameters(): if not p.requires_grad: continue g = p.grad if g is None: none_grad.append(name) else: if torch.isnan(g).any() or torch.isinf(g).any(): nan_inf.append(name) try: norm = float(g.norm().item()) except Exception: norm = 0.0 if norm < tiny_thresh: very_small.append((name, norm)) print(f"[CHECK] grads None: {len(none_grad)}, very_small (<{tiny_thresh}): {len(very_small)}, nan/inf: {len(nan_inf)}") if none_grad: print(" grads is None (first {}):".format(show_n)) for n in none_grad[:show_n]: print(" ", n) if very_small: print(" grads very small (first {}):".format(show_n)) for n, norm in very_small[:show_n]: print(f" {n} norm={norm:.3e}") if nan_inf: print(" grads have NaN/Inf (first {}):".format(show_n)) for n in nan_inf[:show_n]: print(" ", n) def list_frozen_params(self, show_n=50): frozen = [(n, p.shape) for n, p in self.named_parameters() if not p.requires_grad] print(f"[CHECK] frozen params count = {len(frozen)}") for name, shape in frozen[:show_n]: print(" ", name, shape) def check_optimizer_coverage(self): """Check whether configured optimizer(s) include all trainable params""" try: opt_or_list = self.optimizers() except Exception as e: # may be called too early before trainer.configure_optimizers completes print("[CHECK] cannot access optimizer from here:", e) return opts = opt_or_list if isinstance(opt_or_list, (list, tuple)) else [opt_or_list] opt_param_ids = set(id(p) for o in opts for g in o.param_groups for p in g['params']) missing = [n for n, p in self.named_parameters() if p.requires_grad and id(p) not in opt_param_ids] print(f"[CHECK] optimizer missing params count = {len(missing)} (show up to 50):") for n in missing[:50]: print(" ", n) print(" optimizer lrs per opt:", [[g['lr'] for g in o.param_groups] for o in opts]) # ----------------- MOD: hooks ----------------- def on_after_backward(self): """Called by Lightning after backward(): print grad stats (only if debug enabled).""" if not getattr(self, "_check_param_updates", False): return try: print(f"[on_after_backward] step={self.global_step}") # print a few grad stats self.list_params_with_no_grad_after_backward() except Exception as e: print("on_after_backward debug error:", e) def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0): """Called after optimizer.step() in Lightning — compare params snapshot here.""" # Only run our heavy compare when enabled if not getattr(self, "_check_param_updates", False): return try: # Compare params snapshot (if any) before = getattr(self, "_before_params_snapshot", None) if before is None: print("[on_train_batch_end] no before snapshot found") else: self.compare_params_snapshot(before, tol=0.0) # clear snapshot for next step self._before_params_snapshot = None # Also check optimizer coverage (helps discover params not in optimizer) self.check_optimizer_coverage() except Exception as e: print("on_train_batch_end debug error:", e)