|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
from typing import Optional, Protocol, runtime_checkable |
|
|
|
|
|
import moviepy.editor as mpy |
|
|
import torch |
|
|
|
|
|
import swanlab as wandb |
|
|
|
|
|
from einops import pack, rearrange, repeat, einsum |
|
|
from jaxtyping import Float |
|
|
from pytorch_lightning import LightningModule |
|
|
|
|
|
from swanlab.integration.pytorch_lightning import SwanLabLogger |
|
|
|
|
|
|
|
|
from pytorch_lightning.utilities import rank_zero_only |
|
|
from torch import Tensor, nn, optim |
|
|
import numpy as np |
|
|
import json |
|
|
import os |
|
|
import time |
|
|
from tqdm import tqdm |
|
|
import torch.nn.functional as F |
|
|
import math |
|
|
|
|
|
from ..dataset.data_module import get_data_shim |
|
|
from ..dataset.types import BatchedExample |
|
|
from ..dataset import DatasetCfg |
|
|
from ..evaluation.metrics import compute_lpips, compute_psnr, compute_ssim |
|
|
from ..global_cfg import get_cfg |
|
|
from ..loss import Loss |
|
|
from ..misc.benchmarker import Benchmarker |
|
|
from ..misc.image_io import prep_image, save_image, save_video |
|
|
from ..misc.LocalLogger import LOG_PATH, LocalLogger |
|
|
from ..misc.step_tracker import StepTracker |
|
|
from ..visualization.annotation import add_label |
|
|
from ..visualization.camera_trajectory.interpolation import ( |
|
|
interpolate_extrinsics, |
|
|
interpolate_intrinsics, |
|
|
) |
|
|
from ..visualization.camera_trajectory.wobble import ( |
|
|
generate_wobble, |
|
|
generate_wobble_transformation, |
|
|
) |
|
|
from ..visualization.color_map import apply_color_map_to_image |
|
|
from ..visualization.layout import add_border, hcat, vcat |
|
|
from ..visualization.validation_in_3d import render_cameras, render_projections |
|
|
from .decoder.decoder import Decoder, DepthRenderingMode, DecoderOutput |
|
|
from .encoder import Encoder |
|
|
from .encoder.visualization.encoder_visualizer import EncoderVisualizer |
|
|
from src.visualization.vis_depth import viz_depth_tensor |
|
|
from PIL import Image |
|
|
from ..misc.stablize_camera import render_stabilization_path |
|
|
from .ply_export import save_gaussian_ply |
|
|
from .encoder.encoder_volsplat import print_mem |
|
|
import MinkowskiEngine as ME |
|
|
|
|
|
|
|
|
from ..test.visual import save_output_images |
|
|
from ..test.export_ply import export_raw_points_step |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class OptimizerCfg:
    """Hyper-parameters for the optimizer and learning-rate schedule."""

    lr: float  # base learning rate
    warm_up_steps: int  # number of warm-up steps for the LR schedule
    lr_monodepth: float  # separate LR for the monodepth branch — presumably consumed by configure_optimizers (TODO confirm; not visible here)
    weight_decay: float  # optimizer weight decay
|
|
|
|
|
|
|
|
@dataclass
class TestCfg:
    """Options controlling test-time rendering and artifact dumping."""

    output_path: Path  # metrics directory; overwritten in test_step from get_cfg()["output_dir"]/metrics
    compute_scores: bool  # accumulate PSNR/SSIM/LPIPS in test_step and dump them in on_test_end
    save_image: bool  # save rendered target views as PNGs
    save_video: bool  # save rendered target views as an MP4
    eval_time_skip_steps: int  # number of warm-up batches excluded from timing statistics
    save_gt_image: bool  # also save ground-truth target views (only used when save_image is set)
    save_input_images: bool  # save the context (input) views as PNGs
    save_depth: bool  # save per-view depth visualizations
    save_depth_concat_img: bool  # save a single image concatenating inputs and depth maps
    save_depth_npy: bool  # additionally save raw per-view depth arrays as .npy
    save_gaussian: bool  # export the predicted Gaussians as a .ply file
    render_chunk_size: int | None  # render target views in chunks of this many views (None = render all at once)
    stablize_camera: bool  # smooth the target camera path before rendering
    stab_camera_kernel: int  # smoothing kernel size used by render_stabilization_path
|
|
|
|
|
|
|
|
@dataclass
class TrainCfg:
    """Options controlling training, validation, and periodic evaluation."""

    depth_mode: DepthRenderingMode | None  # how the decoder renders depth during training (None = no depth)
    extended_visualization: bool  # additionally render the exaggerated interpolation video
    print_log_every_n_steps: int  # console log frequency (rank 0 only)
    eval_model_every_n_val: int  # run the full test-set eval every N validation epochs
    eval_data_length: int  # max number of batches consumed in the full test-set eval
    eval_deterministic: bool  # also evaluate a deterministic encoder pass during full eval
    eval_time_skip_steps: int  # warm-up batches excluded from timing stats in the full eval
    eval_save_model: bool  # back up a checkpoint before each full eval
    l1_loss: bool  # forwarded to the "mse" loss to switch it to L1
    intermediate_loss_weight: float  # geometric decay weight applied to intermediate-Gaussian losses
    no_viz_video: bool  # skip rendering validation videos
    viz_depth: bool  # also track depth metrics (abs_rel/rmse/a1) in the full eval
    forward_depth_only: bool  # depth-only mode: skip the decoder and RGB metrics entirely
    train_ignore_large_loss: float  # forwarded to the "mse" loss as clamp_large_error
    no_log_projections: bool  # skip logging projection/camera visualizations
|
|
|
|
|
|
|
|
@runtime_checkable
class TrajectoryFn(Protocol):
    """Structural type for camera-trajectory callbacks.

    Given per-frame interpolation parameters ``t``, returns the per-frame
    extrinsics (batch x view x 4 x 4) and intrinsics (batch x view x 3 x 3)
    to render along the trajectory.
    """

    def __call__(
        self,
        t: Float[Tensor, " t"],
    ) -> tuple[
        Float[Tensor, "batch view 4 4"],
        Float[Tensor, "batch view 3 3"],
    ]:
        pass
|
|
|
|
|
|
|
|
class ModelWrapper(LightningModule):
    """Lightning wrapper tying together the encoder (context views ->
    Gaussians), the decoder (Gaussians -> rendered views), and the loss
    functions for training, validation, and testing."""

    # Class-level annotations; the attributes are assigned in __init__
    # (logger is provided by Lightning / the SwanLab integration).
    logger: Optional[SwanLabLogger]
    encoder: nn.Module
    encoder_visualizer: Optional[EncoderVisualizer]
    decoder: Decoder
    losses: nn.ModuleList
    optimizer_cfg: OptimizerCfg
    test_cfg: TestCfg
    train_cfg: TrainCfg
    step_tracker: StepTracker | None
    eval_data_cfg: Optional[DatasetCfg | None]
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
optimizer_cfg: OptimizerCfg, |
|
|
test_cfg: TestCfg, |
|
|
train_cfg: TrainCfg, |
|
|
encoder: Encoder, |
|
|
encoder_visualizer: Optional[EncoderVisualizer], |
|
|
decoder: Decoder, |
|
|
losses: list[Loss], |
|
|
step_tracker: StepTracker | None, |
|
|
eval_data_cfg: Optional[DatasetCfg | None] = None, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.optimizer_cfg = optimizer_cfg |
|
|
self.test_cfg = test_cfg |
|
|
self.train_cfg = train_cfg |
|
|
self.step_tracker = step_tracker |
|
|
self.eval_data_cfg = eval_data_cfg |
|
|
|
|
|
|
|
|
self.encoder = encoder |
|
|
self.encoder_visualizer = encoder_visualizer |
|
|
self.decoder = decoder |
|
|
self.data_shim = get_data_shim(self.encoder) |
|
|
self.losses = nn.ModuleList(losses) |
|
|
|
|
|
|
|
|
self.benchmarker = Benchmarker() |
|
|
self.eval_cnt = 0 |
|
|
|
|
|
if self.test_cfg.compute_scores: |
|
|
self.test_step_outputs = {} |
|
|
self.time_skip_steps_dict = {"encoder": 0, "decoder": 0} |
|
|
|
|
|
|
|
|
self.benchmarker = Benchmarker() |
|
|
self.eval_cnt = 0 |
|
|
|
|
|
if self.test_cfg.compute_scores: |
|
|
self.test_step_outputs = {} |
|
|
self.time_skip_steps_dict = {"encoder": 0, "decoder": 0} |
|
|
|
|
|
|
|
|
self._check_param_updates = False |
|
|
|
|
|
self._before_params_snapshot = None |
|
|
|
|
|
    def training_step(self, batch, batch_idx):
        """Run one optimization step: encode context views into Gaussians,
        render the target views, and accumulate the configured losses.

        Returns the total loss tensor for Lightning to backpropagate.
        """
        batch: BatchedExample = self.data_shim(batch)
        # Image tensors are (batch, view, channel, height, width).
        _, _, _, h, w = batch["target"]["image"].shape
        _, views, _, _, _ = batch["context"]["image"].shape

        # Optional debug hook: snapshot parameters before the step.
        if getattr(self, "_check_param_updates", False):
            try:
                self._before_params_snapshot = self.snapshot_params()
            except Exception as e:
                print("[DEBUG] failed to snapshot params:", e)

        print(f"Training step{self.global_step},Number of images:{views}:scene IDs:{batch['scene']}")

        print_mem("before encoder")

        # Predict Gaussians (and possibly per-view depths) from the context.
        gaussians = self.encoder(
            batch["context"], self.global_step, False, scene_names=batch["scene"]
        )
        print_mem("after encoder")

        # The encoder may return a dict with 2 keys ({"depths", "gaussians"})
        # or 3 keys ({"depths", "intermediate_gaussians", "gaussians"}).
        # NOTE(review): if the encoder returns a bare Gaussians object instead,
        # `supervise_intermediate_depth` and `pred_depths` remain unbound and
        # later references would raise NameError — confirm the encoder always
        # returns a dict on this path.
        if isinstance(gaussians, dict) and len(gaussians) == 2:
            supervise_intermediate_depth = False
            pred_depths = gaussians["depths"]
            gaussians = gaussians["gaussians"]

        if isinstance(gaussians, dict) and len(gaussians) == 3:
            supervise_intermediate_depth = True
            pred_depths = gaussians["depths"]
            intermediate_gaussians = gaussians["intermediate_gaussians"]
            gaussians = gaussians["gaussians"]

        # Sanity check: detect samples whose predicted depths are entirely
        # zero (after replacing NaN/Inf), report the offending scenes, and
        # request a trainer stop.
        # NOTE(review): the RuntimeError raised below is caught by this very
        # try/except and only printed, so the step itself continues; stopping
        # relies solely on trainer.should_stop — confirm that is intended.
        try:
            with torch.no_grad():
                pd = pred_depths.detach()

                B = pd.shape[0]

                pd = torch.nan_to_num(pd, nan=0.0, posinf=0.0, neginf=0.0)
                per_sample_sum = pd.abs().view(B, -1).sum(dim=1)
                zero_mask = (per_sample_sum == 0)

                if zero_mask.any():
                    zero_idx = torch.nonzero(zero_mask, as_tuple=False).squeeze(1)

                    # Normalize to a plain list of ints (squeeze() on a
                    # single element would yield a 0-d tensor).
                    if zero_idx.numel() == 1:
                        zero_idx = [int(zero_idx.item())]
                    else:
                        zero_idx = [int(i.item()) for i in zero_idx]

                    # Best-effort extraction of the scene IDs for reporting.
                    scene_ids = []
                    for i in zero_idx:
                        try:
                            s = batch["scene"][i]
                        except Exception:
                            s = None

                        try:
                            if isinstance(s, torch.Tensor):
                                if s.numel() == 1:
                                    scene_ids.append(s.item())
                                else:
                                    scene_ids.append(s.cpu().tolist())
                            else:
                                scene_ids.append(s)
                        except Exception:
                            scene_ids.append(str(s))

                    # Only rank 0 prints, to avoid duplicated output in DDP.
                    is_rank0 = True
                    try:
                        is_rank0 = getattr(self.trainer, "global_rank", getattr(self, "global_rank", 0)) == 0
                    except Exception:
                        pass

                    if is_rank0:
                        print(f"[STOPPING] pred_depths all-zero detected for batch indices {zero_idx}; scene IDs: {scene_ids}")

                    # Ask Lightning to stop at the next opportunity.
                    try:
                        self.trainer.should_stop = True
                    except Exception:
                        pass

                    raise RuntimeError(f"Stopping training because pred_depths are all zero for scenes: {scene_ids}")

        except Exception as e:
            print("[DEBUG] pred_depths zero-check failed or triggered stop. info:", e)

        print_mem("before decoder")

        # When the Gaussians' batch dimension is a multiple of the target
        # batch (intermediate Gaussians stacked along the batch axis), tile
        # the target cameras, render everything in one decoder call, then
        # split: the last batch_size entries are the final prediction and the
        # rest are intermediate outputs.
        if gaussians.means.size(0) != batch["target"]["extrinsics"].size(0):
            supervise_intermediate_depth = True
            assert gaussians.means.size(0) % batch["target"]["extrinsics"].size(0) == 0
            num_depths = gaussians.means.size(0) // batch["target"]["extrinsics"].size(
                0
            )

            target_extrinsics = torch.cat(
                [batch["target"]["extrinsics"]] * num_depths, dim=0
            )
            target_intrinsics = torch.cat(
                [batch["target"]["intrinsics"]] * num_depths, dim=0
            )
            target_near = torch.cat([batch["target"]["near"]] * num_depths, dim=0)
            target_far = torch.cat([batch["target"]["far"]] * num_depths, dim=0)

            output_all = self.decoder.forward(
                gaussians,
                target_extrinsics,
                target_intrinsics,
                target_near,
                target_far,
                (h, w),
                depth_mode=self.train_cfg.depth_mode,
            )

            batch_size = batch["target"]["extrinsics"].size(0)

            output_intermediate = DecoderOutput(
                color=output_all.color[:-batch_size],
                depth=(
                    output_all.depth[:-batch_size]
                    if output_all.depth is not None
                    else None
                ),
            )
            output = DecoderOutput(
                color=output_all.color[-batch_size:],
                depth=(
                    output_all.depth[-batch_size:]
                    if output_all.depth is not None
                    else None
                ),
            )

        else:
            output = self.decoder.forward(
                gaussians,
                batch["target"]["extrinsics"],
                batch["target"]["intrinsics"],
                batch["target"]["near"],
                batch["target"]["far"],
                (h, w),
                depth_mode=self.train_cfg.depth_mode,
            )

            # Intermediate Gaussians came as a separate object: render them
            # with a dedicated decoder call.
            if supervise_intermediate_depth:
                output_intermediate = self.decoder.forward(
                    intermediate_gaussians,
                    batch["target"]["extrinsics"],
                    batch["target"]["intrinsics"],
                    batch["target"]["near"],
                    batch["target"]["far"],
                    (h, w),
                    depth_mode=self.train_cfg.depth_mode,
                )

        print_mem("after decoder")

        target_gt = batch["target"]["image"]

        # Training PSNR, computed per (batch, view) pair.
        psnr_probabilistic = compute_psnr(
            rearrange(target_gt, "b v c h w -> (b v) c h w"),
            rearrange(output.color, "b v c h w -> (b v) c h w"),
        )

        self.log("train/psnr", psnr_probabilistic.mean())

        total_loss = 0

        # No depth-validity masking is applied during training.
        valid_depth_mask = None

        # Main losses on the final prediction. The "mse" loss takes extra
        # keyword arguments from the training config.
        for loss_fn in self.losses:
            if loss_fn.name == "mse":
                loss = loss_fn.forward(
                    output,
                    batch,
                    gaussians,
                    self.global_step,
                    l1_loss=self.train_cfg.l1_loss,
                    clamp_large_error=self.train_cfg.train_ignore_large_loss,
                    valid_depth_mask=valid_depth_mask,
                )
            else:
                loss = loss_fn.forward(
                    output,
                    batch,
                    gaussians,
                    self.global_step,
                    valid_depth_mask=valid_depth_mask,
                )
            self.log(f"loss/{loss_fn.name}", loss)
            total_loss = total_loss + loss

        # Auxiliary supervision on intermediate outputs, weighted by
        # intermediate_loss_weight (decaying geometrically with distance from
        # the final output when there are several intermediate stages).
        if supervise_intermediate_depth:
            for loss_fn in self.losses:
                batch_size = batch["target"]["extrinsics"].size(0)
                if output_intermediate.color.size(0) != batch_size:
                    # Several intermediate stages stacked along the batch
                    # axis: slice them apart and weight each stage.
                    assert output_intermediate.color.size(0) % batch_size == 0
                    num_intermediate = output_intermediate.color.size(0) // batch_size
                    intermediate_loss = 0
                    for i in range(num_intermediate):
                        curr_output = DecoderOutput(
                            color=output_intermediate.color[
                                (batch_size * i) : (batch_size * (i + 1))
                            ],
                            depth=(
                                output_intermediate.depth[
                                    (batch_size * i) : (batch_size * (i + 1))
                                ]
                                if output_intermediate.depth is not None
                                else None
                            ),
                        )
                        # Earlier stages (small i) get a smaller weight.
                        curr_loss_weight = self.train_cfg.intermediate_loss_weight ** (
                            num_intermediate - i
                        )

                        if loss_fn.name == "mse":
                            loss = loss_fn.forward(
                                curr_output,
                                batch,
                                gaussians,
                                self.global_step,
                                l1_loss=self.train_cfg.l1_loss,
                                clamp_large_error=self.train_cfg.train_ignore_large_loss,
                                valid_depth_mask=valid_depth_mask,
                            )
                        else:
                            loss = loss_fn.forward(
                                curr_output,
                                batch,
                                gaussians,
                                self.global_step,
                                valid_depth_mask=valid_depth_mask,
                            )

                        intermediate_loss = intermediate_loss + curr_loss_weight * loss

                    self.log(f"loss/{loss_fn.name}_intermediate", intermediate_loss)
                    total_loss = total_loss + intermediate_loss
                else:
                    # Single intermediate stage: apply the loss directly with
                    # a single weight factor.
                    if loss_fn.name == "mse":
                        loss = loss_fn.forward(
                            output_intermediate,
                            batch,
                            gaussians,
                            self.global_step,
                            l1_loss=self.train_cfg.l1_loss,
                            clamp_large_error=self.train_cfg.train_ignore_large_loss,
                            valid_depth_mask=valid_depth_mask,
                        )
                    else:
                        loss = loss_fn.forward(
                            output_intermediate,
                            batch,
                            gaussians,
                            self.global_step,
                            valid_depth_mask=valid_depth_mask,
                        )
                    self.log(f"loss/{loss_fn.name}_intermediate", loss)
                    total_loss = (
                        total_loss + self.train_cfg.intermediate_loss_weight * loss
                    )

        self.log("loss/total", total_loss)

        # Periodic console logging (rank 0 only).
        if (
            self.global_rank == 0
            and self.global_step % self.train_cfg.print_log_every_n_steps == 0
        ):
            print(
                f"train step {self.global_step}; "
                f"scene = {[x[:20] for x in batch['scene']]}; "
                f"context = {batch['context']['index'].tolist()}; "
                f"bound = [{batch['context']['near'].detach().cpu().numpy().mean()} "
                f"{batch['context']['far'].detach().cpu().numpy().mean()}]; "
                f"loss = {total_loss:.6f}"
            )
        self.log("info/near", batch["context"]["near"].detach().cpu().numpy().mean())
        self.log("info/far", batch["context"]["far"].detach().cpu().numpy().mean())
        self.log("info/global_step", self.global_step)

        # Propagate the global step to the shared step tracker (e.g. for
        # dataset curricula).
        if self.step_tracker is not None:
            self.step_tracker.set_step(self.global_step)

        # One-off GPU memory report early in training.
        if self.global_step == 5 and self.global_rank == 0:
            os.system("nvidia-smi")

        return total_loss
|
|
|
|
|
    def test_step(self, batch, batch_idx):
        """Evaluate one test batch: encode, optionally render target views
        (chunked and/or along a stabilized camera path), and dump the
        configured artifacts (inputs, depths, Gaussians, images, video)
        plus PSNR/SSIM/LPIPS scores. Only batch size 1 is supported.
        """
        batch: BatchedExample = self.data_shim(batch)
        b, v, _, h, w = batch["target"]["image"].shape
        assert b == 1

        pred_depths = None

        # Optionally dump the context (input) views.
        if self.test_cfg.save_input_images:
            (scene,) = batch["scene"]
            self.test_cfg.output_path = os.path.join(get_cfg()["output_dir"], "metrics")
            path = Path(get_cfg()["output_dir"])

            input_images = batch["context"]["image"][0]
            index = batch["context"]["index"][0]
            for idx, color in zip(index, input_images):
                save_image(color, path / "images" / scene / f"color/input_{idx:0>6}.png")

        # The encoder fills visualization_dump (e.g. with per-view depth)
        # only when we will need it afterwards.
        if self.test_cfg.save_depth or self.test_cfg.save_gaussian:
            visualization_dump = {}
        else:
            visualization_dump = None

        with self.benchmarker.time("encoder"):
            gaussians = self.encoder(
                batch["context"],
                self.global_step,
                deterministic=False,
                visualization_dump=visualization_dump,
            )

        # Unwrap a dict-style encoder result into depths + Gaussians.
        # NOTE(review): depth_gt is assigned but never used in this method.
        if isinstance(gaussians, dict):
            pred_depths = gaussians["depths"]
            if "depth" in batch["context"]:
                depth_gt = batch["context"]["depth"]
            gaussians = gaussians["gaussians"]

        # Optionally export the predicted Gaussians as a point cloud.
        if self.test_cfg.save_gaussian:
            scene = batch["scene"][0]
            save_path = Path(get_cfg()['output_dir']) / 'gaussians' / (scene + '.ply')
            save_gaussian_ply(gaussians, visualization_dump, batch, save_path)

        if not self.train_cfg.forward_depth_only:
            with self.benchmarker.time("decoder", num_calls=v):

                camera_poses = batch["target"]["extrinsics"]

                # Optionally smooth the target camera path; the smoothed 3x4
                # poses are re-homogenized with a [0 0 0 1] bottom row.
                if self.test_cfg.stablize_camera:
                    stable_poses = render_stabilization_path(
                        camera_poses[0].detach().cpu().numpy(),
                        k_size=self.test_cfg.stab_camera_kernel,
                    )

                    stable_poses = list(
                        map(
                            lambda x: np.concatenate(
                                (x, np.array([[0.0, 0.0, 0.0, 1.0]])), axis=0
                            ),
                            stable_poses,
                        )
                    )
                    stable_poses = torch.from_numpy(np.stack(stable_poses, axis=0)).to(
                        camera_poses
                    )
                    camera_poses = stable_poses.unsqueeze(0)

                # Render the target views, in chunks if configured (keeps
                # peak memory bounded for long trajectories).
                if self.test_cfg.render_chunk_size is not None:
                    chunk_size = self.test_cfg.render_chunk_size
                    num_chunks = math.ceil(camera_poses.shape[1] / chunk_size)

                    output = None
                    for i in range(num_chunks):
                        start = chunk_size * i
                        end = chunk_size * (i + 1)

                        render_intrinsics = batch["target"]["intrinsics"]
                        render_near = batch["target"]["near"]
                        render_far = batch["target"]["far"]

                        curr_output = self.decoder.forward(
                            gaussians,
                            camera_poses[:, start:end],
                            render_intrinsics[:, start:end],
                            render_near[:, start:end],
                            render_far[:, start:end],
                            (h, w),
                            depth_mode=None,
                        )

                        # Accumulate chunk colors along the view dimension.
                        if i == 0:
                            output = curr_output
                        else:
                            output.color = torch.cat(
                                (output.color, curr_output.color), dim=1
                            )

                else:
                    output = self.decoder.forward(
                        gaussians,
                        camera_poses,
                        batch["target"]["intrinsics"],
                        batch["target"]["near"],
                        batch["target"]["far"],
                        (h, w),
                        depth_mode=None,
                    )

        (scene,) = batch["scene"]
        self.test_cfg.output_path = os.path.join(get_cfg()["output_dir"], "metrics")
        path = Path(get_cfg()["output_dir"])

        # Depth dumping: per-view colorized PNGs, optionally raw .npy arrays,
        # and optionally one image concatenating inputs and depth maps.
        if self.test_cfg.save_depth:
            if self.train_cfg.forward_depth_only:
                depth = pred_depths[0].cpu().detach()
            else:
                # Depth recorded by the encoder into the visualization dump.
                # assumes dump["depth"] is indexable as [b, v, h, w, ..., ...]
                # — TODO confirm against the encoder.
                depth = (
                    visualization_dump["depth"][0, :, :, :, 0, 0].cpu().detach()
                )

            index = batch["context"]["index"][0]

            if self.test_cfg.save_depth_concat_img:
                image = batch['context']['image'][0]
                image = rearrange(image, "b c h w -> h (b w) c")
                image_concat = (image.detach().cpu().numpy() * 255).astype(np.uint8)

            depth_concat = []

            for idx, depth_i in zip(index, depth):
                # Colorize inverse depth (disparity) for visualization.
                depth_viz = viz_depth_tensor(
                    1.0 / depth_i, return_numpy=True
                )

                if self.test_cfg.save_depth_concat_img:
                    depth_concat.append(depth_viz)

                save_path = path / "images" / scene / "depth" / f"{idx:0>6}.png"
                save_dir = os.path.dirname(save_path)
                os.makedirs(save_dir, exist_ok=True)
                Image.fromarray(depth_viz).save(save_path)

                if self.test_cfg.save_depth_npy:
                    depth_npy = depth_i.detach().cpu().numpy()
                    save_path = path / "images" / scene / "depth" / f"{idx:0>6}.npy"
                    save_dir = os.path.dirname(save_path)
                    os.makedirs(save_dir, exist_ok=True)
                    np.save(save_path, depth_npy)

            if self.test_cfg.save_depth_concat_img:
                depth_concat = np.concatenate(depth_concat, axis=1)
                concat = np.concatenate((image_concat, depth_concat), axis=0)

                save_path = path / "images" / scene / "depth" / f"img_depth_{scene}.png"
                save_dir = os.path.dirname(save_path)
                os.makedirs(save_dir, exist_ok=True)
                Image.fromarray(concat).save(save_path)

        # Depth-only mode: nothing was rendered, so there is nothing to score.
        if self.train_cfg.forward_depth_only:
            return

        images_prob = output.color[0]
        rgb_gt = batch["target"]["image"][0]

        # Save rendered (and optionally ground-truth) target views.
        if self.test_cfg.save_image:
            if self.test_cfg.save_gt_image:
                for index, color, gt in zip(
                    batch["target"]["index"][0], images_prob, rgb_gt
                ):
                    save_image(color, path / "images" / scene / f"color/{index:0>6}.png")
                    save_image(gt, path / "images" / scene / f"color/{index:0>6}_gt.png")
            else:
                for index, color in zip(batch["target"]["index"][0], images_prob):
                    save_image(color, path / "images" / scene / f"color/{index:0>6}.png")

        # Save the rendered target views as a video, named by context frames.
        if self.test_cfg.save_video:
            frame_str = "_".join([str(x.item()) for x in batch["context"]["index"][0]])
            save_video(
                [a for a in images_prob],
                path / "videos" / f"{scene}_frame_{frame_str}.mp4",
            )

        if self.test_cfg.compute_scores:
            # Count warm-up calls so on_test_end can exclude them from
            # timing statistics.
            if batch_idx < self.test_cfg.eval_time_skip_steps:
                self.time_skip_steps_dict["encoder"] += 1
                self.time_skip_steps_dict["decoder"] += v

            if not self.train_cfg.forward_depth_only:
                rgb = images_prob

                if f"psnr" not in self.test_step_outputs:
                    self.test_step_outputs[f"psnr"] = []
                if f"ssim" not in self.test_step_outputs:
                    self.test_step_outputs[f"ssim"] = []
                if f"lpips" not in self.test_step_outputs:
                    self.test_step_outputs[f"lpips"] = []

                self.test_step_outputs[f"psnr"].append(
                    compute_psnr(rgb_gt, rgb).mean().item()
                )
                self.test_step_outputs[f"ssim"].append(
                    compute_ssim(rgb_gt, rgb).mean().item()
                )
                self.test_step_outputs[f"lpips"].append(
                    compute_lpips(rgb_gt, rgb).mean().item()
                )
|
|
|
|
|
def on_test_end(self) -> None: |
|
|
out_dir = Path(self.test_cfg.output_path) |
|
|
saved_scores = {} |
|
|
if self.test_cfg.compute_scores: |
|
|
self.benchmarker.dump_memory(out_dir / "peak_memory.json") |
|
|
self.benchmarker.dump(out_dir / "benchmark.json") |
|
|
|
|
|
for metric_name, metric_scores in self.test_step_outputs.items(): |
|
|
avg_scores = sum(metric_scores) / len(metric_scores) |
|
|
saved_scores[metric_name] = avg_scores |
|
|
print(metric_name, avg_scores) |
|
|
with (out_dir / f"scores_{metric_name}_all.json").open("w") as f: |
|
|
json.dump(metric_scores, f) |
|
|
metric_scores.clear() |
|
|
|
|
|
for tag, times in self.benchmarker.execution_times.items(): |
|
|
times = times[int(self.time_skip_steps_dict[tag]) :] |
|
|
saved_scores[tag] = [len(times), np.mean(times)] |
|
|
print( |
|
|
f"{tag}: {len(times)} calls, avg. {np.mean(times)} seconds per call" |
|
|
) |
|
|
self.time_skip_steps_dict[tag] = 0 |
|
|
|
|
|
with (out_dir / f"scores_all_avg.json").open("w") as f: |
|
|
json.dump(saved_scores, f) |
|
|
self.benchmarker.clear_history() |
|
|
else: |
|
|
self.benchmarker.dump(out_dir / "benchmark.json") |
|
|
self.benchmarker.dump_memory(out_dir / "peak_memory.json") |
|
|
self.benchmarker.summarize() |
|
|
|
|
|
    @rank_zero_only
    def validation_step(self, batch, batch_idx):
        """Run one validation batch (rank 0 only): encode, render, log
        PSNR/SSIM/LPIPS, and log depth / comparison / projection / camera
        visualizations and videos. Only batch size 1 is supported.
        """
        batch: BatchedExample = self.data_shim(batch)

        if self.global_rank == 0:
            print(
                f"validation step {self.global_step}; "
                f"scene = {[a[:20] for a in batch['scene']]}; "
                f"context = {batch['context']['index'].tolist()}"
            )

        # Render Gaussians.
        b, _, _, h, w = batch["target"]["image"].shape
        assert b == 1
        gaussians_softmax = self.encoder(
            batch["context"],
            self.global_step,
            deterministic=False,
        )

        pred_depths = None

        # Unwrap a dict-style encoder result into depths + Gaussians.
        # NOTE(review): depth_gt is assigned but never used in this method.
        if isinstance(gaussians_softmax, dict):
            pred_depths = gaussians_softmax["depths"]
            if "depth" in batch["context"]:
                depth_gt = batch["context"]["depth"]
            gaussians_softmax = gaussians_softmax["gaussians"]

        if not self.train_cfg.forward_depth_only:
            output_softmax = self.decoder.forward(
                gaussians_softmax,
                batch["target"]["extrinsics"],
                batch["target"]["intrinsics"],
                batch["target"]["near"],
                batch["target"]["far"],
                (h, w),
            )
            rgb_softmax = output_softmax.color[0]

            # Compute validation metrics against the ground truth.
            rgb_gt = batch["target"]["image"][0]
            for tag, rgb in zip(("val",), (rgb_softmax,)):
                psnr = compute_psnr(rgb_gt, rgb).mean()
                self.log(f"val/psnr_{tag}", psnr)
                lpips = compute_lpips(rgb_gt, rgb).mean()
                self.log(f"val/lpips_{tag}", lpips)
                ssim = compute_ssim(rgb_gt, rgb).mean()
                self.log(f"val/ssim_{tag}", ssim)

        # Log a depth visualization: context images on top, colorized
        # inverse depth below.
        if pred_depths is not None:
            pred_depths = pred_depths[0]

            # Upsample predicted depth to the context image resolution if
            # the encoder produced it at a lower resolution.
            if pred_depths.shape[1:] != batch["context"]["image"].shape[-2:]:
                pred_depths = F.interpolate(
                    pred_depths.unsqueeze(1),
                    size=batch["context"]["image"].shape[-2:],
                    mode="bilinear",
                    align_corners=True,
                ).squeeze(1)

            inverse_depth_pred = 1.0 / pred_depths

            # Concatenate all views horizontally.
            concat = []
            for i in range(inverse_depth_pred.size(0)):
                concat.append(inverse_depth_pred[i])

            concat = torch.cat(concat, dim=1)

            depth_viz = viz_depth_tensor(concat.cpu().detach())

            input_images = batch["context"]["image"][0]
            concat_img = [img for img in input_images]
            concat_img = torch.cat(concat_img, dim=-1) * 255

            concat = torch.cat(
                (concat_img.cpu().detach(), depth_viz), dim=1
            )

            self.logger.log_image(
                "depth",
                [concat],
                step=self.global_step,
                caption=batch["scene"],
            )

        if not self.train_cfg.forward_depth_only:

            # Side-by-side comparison of context, ground truth, prediction.
            comparison = hcat(
                add_label(vcat(*batch["context"]["image"][0]), "Context"),
                add_label(vcat(*rgb_gt), "Target (Ground Truth)"),
                add_label(vcat(*rgb_softmax), "Target (Prediction)"),
            )
            self.logger.log_image(
                "comparison",
                [prep_image(add_border(comparison))],
                step=self.global_step,
                caption=batch["scene"],
            )

            if not self.train_cfg.no_log_projections:
                # Orthographic projections of the predicted Gaussians.
                projections = hcat(
                    *render_projections(
                        gaussians_softmax,
                        256,
                        extra_label="(Prediction)",
                    )[0]
                )
                self.logger.log_image(
                    "projection",
                    [prep_image(add_border(projections))],
                    step=self.global_step,
                )

                # Draw the camera layout.
                cameras = hcat(*render_cameras(batch, 256))
                self.logger.log_image(
                    "cameras", [prep_image(add_border(cameras))], step=self.global_step
                )

                if self.encoder_visualizer is not None:
                    for k, image in self.encoder_visualizer.visualize(
                        batch["context"], self.global_step
                    ).items():
                        self.logger.log_image(k, [prep_image(image)], step=self.global_step)

            # Render trajectory videos unless disabled.
            if not self.train_cfg.no_viz_video:
                self.render_video_interpolation(batch)
                self.render_video_wobble(batch)
                if self.train_cfg.extended_visualization:
                    self.render_video_interpolation_exaggerated(batch)
|
|
|
|
|
    def on_validation_epoch_end(self) -> None:
        """hack to run the full validation"""
        # During the sanity-check pass, just print the model once.
        if self.trainer.sanity_checking and self.global_rank == 0:
            print(self.encoder)

        # Every eval_model_every_n_val real validation epochs (and only when
        # an eval dataset config was supplied), optionally back up a
        # checkpoint and run the full test-set evaluation.
        if (not self.trainer.sanity_checking) and (self.eval_data_cfg is not None):
            self.eval_cnt = self.eval_cnt + 1
            if self.eval_cnt % self.train_cfg.eval_model_every_n_val == 0:

                if self.train_cfg.eval_save_model:
                    # Derive the canonical checkpoint name, then redirect it
                    # into a sibling "checkpoints_backups" directory.
                    ckpt_saved_path = (
                        self.trainer.checkpoint_callback.format_checkpoint_name(
                            dict(
                                epoch=self.trainer.current_epoch,
                                step=self.trainer.global_step,
                            )
                        )
                    )
                    backup_dir = str(
                        Path(ckpt_saved_path).parent.parent / "checkpoints_backups"
                    )
                    if self.global_rank == 0:
                        os.makedirs(backup_dir, exist_ok=True)
                    ckpt_saved_path = os.path.join(
                        backup_dir, os.path.basename(ckpt_saved_path)
                    )

                    # save_checkpoint is collective: all ranks call it, while
                    # directory creation and printing are rank 0 only.
                    self.trainer.save_checkpoint(
                        ckpt_saved_path,
                        weights_only=True,
                    )
                    if self.global_rank == 0:
                        print(f"backup model to {ckpt_saved_path}.")

                self.run_full_test_sets_eval()
|
|
|
|
|
    @rank_zero_only
    def run_full_test_sets_eval(self) -> None:
        """Evaluate the model on the full test dataloader (rank 0 only),
        logging PSNR/SSIM/LPIPS and encoder/decoder runtimes under the
        ``test/`` namespace.
        """
        start_t = time.time()

        # NOTE(review): pred_depths / depth_gt are collected per batch but
        # never consumed afterwards in this method.
        pred_depths = None
        depth_gt = None

        full_testsets = self.trainer.datamodule.test_dataloader(
        )
        scores_dict = {}

        # Pre-create score buckets: metric -> method -> list of values.
        if not self.train_cfg.forward_depth_only:
            for score_tag in ("psnr", "ssim", "lpips"):
                scores_dict[score_tag] = {}
                for method_tag in ("deterministic", "probabilistic"):
                    scores_dict[score_tag][method_tag] = []

        # Depth-metric buckets — NOTE(review): nothing below appends to
        # these; the depth metrics appear to be computed elsewhere or not
        # yet implemented.
        if self.train_cfg.viz_depth:
            for score_tag in ("abs_rel", "rmse", "a1"):
                scores_dict[score_tag] = {}
                for method_tag in ("deterministic", "probabilistic"):
                    scores_dict[score_tag][method_tag] = []

        self.benchmarker.clear_history()
        # Warm-up batches excluded from timing statistics.
        time_skip_first_n_steps = min(
            self.train_cfg.eval_time_skip_steps, len(full_testsets)
        )
        time_skip_steps_dict = {"encoder": 0, "decoder": 0}
        for batch_idx, batch in tqdm(
            enumerate(full_testsets),
            total=min(len(full_testsets), self.train_cfg.eval_data_length),
        ):
            # Cap the evaluation at eval_data_length batches.
            if batch_idx >= self.train_cfg.eval_data_length:
                break

            batch = self.data_shim(batch)
            # The dataloader is driven manually here, so move the batch to
            # the GPU ourselves.
            batch = self.transfer_batch_to_device(batch, "cuda", dataloader_idx=0)

            b, v, _, h, w = batch["target"]["image"].shape
            assert b == 1
            if batch_idx < time_skip_first_n_steps:
                time_skip_steps_dict["encoder"] += 1
                time_skip_steps_dict["decoder"] += v

            with self.benchmarker.time("encoder"):
                gaussians_probabilistic = self.encoder(
                    batch["context"],
                    self.global_step,
                    deterministic=False,
                )

            # Unwrap a dict-style encoder result into depths + Gaussians.
            if isinstance(gaussians_probabilistic, dict):
                pred_depths = gaussians_probabilistic["depths"]
                if "depth" in batch["context"]:
                    depth_gt = batch["context"]["depth"]
                gaussians_probabilistic = gaussians_probabilistic["gaussians"]

            if not self.train_cfg.forward_depth_only:
                with self.benchmarker.time("decoder", num_calls=v):
                    output_probabilistic = self.decoder.forward(
                        gaussians_probabilistic,
                        batch["target"]["extrinsics"],
                        batch["target"]["intrinsics"],
                        batch["target"]["near"],
                        batch["target"]["far"],
                        (h, w),
                    )
                rgbs = [output_probabilistic.color[0]]
                tags = ["probabilistic"]

                # Optional second pass with a deterministic encoder.
                if self.train_cfg.eval_deterministic:
                    gaussians_deterministic = self.encoder(
                        batch["context"],
                        self.global_step,
                        deterministic=True,
                    )
                    output_deterministic = self.decoder.forward(
                        gaussians_deterministic,
                        batch["target"]["extrinsics"],
                        batch["target"]["intrinsics"],
                        batch["target"]["near"],
                        batch["target"]["far"],
                        (h, w),
                    )
                    rgbs.append(output_deterministic.color[0])
                    tags.append("deterministic")

                # Score each rendered variant against the ground truth.
                rgb_gt = batch["target"]["image"][0]
                for tag, rgb in zip(tags, rgbs):
                    scores_dict["psnr"][tag].append(
                        compute_psnr(rgb_gt, rgb).mean().item()
                    )
                    scores_dict["lpips"][tag].append(
                        compute_lpips(rgb_gt, rgb).mean().item()
                    )
                    scores_dict["ssim"][tag].append(
                        compute_ssim(rgb_gt, rgb).mean().item()
                    )

        # Log the mean of every non-empty score bucket.
        # NOTE(review): both methods log to the same "test/{score_tag}" key,
        # so the deterministic mean would overwrite the probabilistic one —
        # confirm whether per-method keys were intended.
        for score_tag, methods in scores_dict.items():
            for method_tag, cur_scores in methods.items():
                if len(cur_scores) > 0:
                    cur_mean = sum(cur_scores) / len(cur_scores)
                    self.log(f"test/{score_tag}", cur_mean)

        # Report runtimes, excluding the warm-up calls recorded above.
        for tag, times in self.benchmarker.execution_times.items():
            times = times[int(time_skip_steps_dict[tag]) :]
            print(f"{tag}: {len(times)} calls, avg. {np.mean(times)} seconds per call")
            self.log(f"test/runtime_avg_{tag}", np.mean(times))
        self.benchmarker.clear_history()

        overall_eval_time = time.time() - start_t
        print(f"Eval total time cost: {overall_eval_time:.3f}s")
        self.log("test/runtime_all", overall_eval_time)
|
|
|
|
|
@rank_zero_only |
|
|
def render_video_wobble(self, batch: BatchedExample) -> None: |
|
|
|
|
|
_, v, _, _ = batch["context"]["extrinsics"].shape |
|
|
if v != 2: |
|
|
return |
|
|
|
|
|
def trajectory_fn(t): |
|
|
origin_a = batch["context"]["extrinsics"][:, 0, :3, 3] |
|
|
origin_b = batch["context"]["extrinsics"][:, 1, :3, 3] |
|
|
delta = (origin_a - origin_b).norm(dim=-1) |
|
|
extrinsics = generate_wobble( |
|
|
batch["context"]["extrinsics"][:, 0], |
|
|
delta * 0.25, |
|
|
t, |
|
|
) |
|
|
intrinsics = repeat( |
|
|
batch["context"]["intrinsics"][:, 0], |
|
|
"b i j -> b v i j", |
|
|
v=t.shape[0], |
|
|
) |
|
|
return extrinsics, intrinsics |
|
|
|
|
|
return self.render_video_generic(batch, trajectory_fn, "wobble", num_frames=60) |
|
|
|
|
|
@rank_zero_only |
|
|
def render_video_interpolation(self, batch: BatchedExample) -> None: |
|
|
_, v, _, _ = batch["context"]["extrinsics"].shape |
|
|
|
|
|
def trajectory_fn(t): |
|
|
extrinsics = interpolate_extrinsics( |
|
|
batch["context"]["extrinsics"][0, 0], |
|
|
( |
|
|
batch["context"]["extrinsics"][0, 1] |
|
|
if v == 2 |
|
|
else batch["target"]["extrinsics"][0, 0] |
|
|
), |
|
|
t, |
|
|
) |
|
|
intrinsics = interpolate_intrinsics( |
|
|
batch["context"]["intrinsics"][0, 0], |
|
|
( |
|
|
batch["context"]["intrinsics"][0, 1] |
|
|
if v == 2 |
|
|
else batch["target"]["intrinsics"][0, 0] |
|
|
), |
|
|
t, |
|
|
) |
|
|
return extrinsics[None], intrinsics[None] |
|
|
|
|
|
return self.render_video_generic(batch, trajectory_fn, "rgb") |
|
|
|
|
|
@rank_zero_only |
|
|
def render_video_interpolation_exaggerated(self, batch: BatchedExample) -> None: |
|
|
|
|
|
_, v, _, _ = batch["context"]["extrinsics"].shape |
|
|
if v != 2: |
|
|
return |
|
|
|
|
|
def trajectory_fn(t): |
|
|
origin_a = batch["context"]["extrinsics"][:, 0, :3, 3] |
|
|
origin_b = batch["context"]["extrinsics"][:, 1, :3, 3] |
|
|
delta = (origin_a - origin_b).norm(dim=-1) |
|
|
tf = generate_wobble_transformation( |
|
|
delta * 0.5, |
|
|
t, |
|
|
5, |
|
|
scale_radius_with_t=False, |
|
|
) |
|
|
extrinsics = interpolate_extrinsics( |
|
|
batch["context"]["extrinsics"][0, 0], |
|
|
( |
|
|
batch["context"]["extrinsics"][0, 1] |
|
|
if v == 2 |
|
|
else batch["target"]["extrinsics"][0, 0] |
|
|
), |
|
|
t * 5 - 2, |
|
|
) |
|
|
intrinsics = interpolate_intrinsics( |
|
|
batch["context"]["intrinsics"][0, 0], |
|
|
( |
|
|
batch["context"]["intrinsics"][0, 1] |
|
|
if v == 2 |
|
|
else batch["target"]["intrinsics"][0, 0] |
|
|
), |
|
|
t * 5 - 2, |
|
|
) |
|
|
return extrinsics @ tf, intrinsics[None] |
|
|
|
|
|
return self.render_video_generic( |
|
|
batch, |
|
|
trajectory_fn, |
|
|
"interpolation_exagerrated", |
|
|
num_frames=300, |
|
|
smooth=False, |
|
|
loop_reverse=False, |
|
|
) |
|
|
|
|
|
    @rank_zero_only
    def render_video_generic(
        self,
        batch: BatchedExample,
        trajectory_fn: TrajectoryFn,
        name: str,
        num_frames: int = 30,
        smooth: bool = True,
        loop_reverse: bool = True,
    ) -> None:
        """Render a video along a camera trajectory predicted from the context views.

        Args:
            batch: Batched example providing context images/cameras.
            trajectory_fn: Maps interpolation parameter t (num_frames,) to
                (extrinsics, intrinsics) for every frame.
            name: Label for the rendered video (used by callers as output name).
            num_frames: Number of frames sampled along the trajectory.
            smooth: If True, ease-in/ease-out the trajectory via a cosine ramp.
            loop_reverse: If True, append the reversed frames to make a loop.
        """
        # Encode the context views into Gaussians (non-deterministic mode).
        gaussians_prob = self.encoder(batch["context"], self.global_step, False)

        # Some encoder variants return a dict; unwrap the Gaussians if so.
        if isinstance(gaussians_prob, dict):
            gaussians_prob = gaussians_prob["gaussians"]

        t = torch.linspace(0, 1, num_frames, dtype=torch.float32, device=self.device)
        if smooth:
            # Cosine easing: still in [0, 1] but slows near both endpoints.
            t = (torch.cos(torch.pi * (t + 1)) + 1) / 2

        extrinsics, intrinsics = trajectory_fn(t)

        _, _, _, h, w = batch["context"]["image"].shape

        def depth_map(result):
            # Normalize log-depth between robust 1%/99% quantiles, inverted so
            # near is bright, then colorize. The [:16_000_000] caps quantile
            # input size; near masks out zeros while far uses the flat view —
            # NOTE(review): that asymmetry looks intentional but unconfirmed.
            near = result[result > 0][:16_000_000].quantile(0.01).log()
            far = result.view(-1)[:16_000_000].quantile(0.99).log()
            result = result.log()
            result = 1 - (result - near) / (far - near)
            return apply_color_map_to_image(result, "turbo")

        # Reuse the first context view's near/far planes for every frame.
        near = repeat(batch["context"]["near"][:, 0], "b -> b v", v=num_frames)
        far = repeat(batch["context"]["far"][:, 0], "b -> b v", v=num_frames)
        output_prob = self.decoder.forward(
            gaussians_prob, extrinsics, intrinsics, near, far, (h, w), "depth"
        )
        # Stack RGB over colorized depth for each frame (batch element 0).
        images_prob = [
            vcat(rgb, depth)
            for rgb, depth in zip(output_prob.color[0], depth_map(output_prob.depth[0]))
        ]

        images = [
            add_border(
                hcat(
                    add_label(image_prob, "Prediction"),
                )
            )
            for image_prob, _ in zip(images_prob, images_prob)
        ]

        video = torch.stack(images)
        video = (video.clip(min=0, max=1) * 255).type(torch.uint8).cpu().numpy()
        if loop_reverse:
            # Append the reversed frames (minus duplicated endpoints) for a loop.
            video = pack([video, video[::-1][1:-1]], "* c h w")[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def configure_optimizers(self): |
|
|
|
|
|
if self.optimizer_cfg.lr_monodepth > 0: |
|
|
pretrained_params = [] |
|
|
new_params = [] |
|
|
|
|
|
for name, param in self.named_parameters(): |
|
|
if "pretrained" in name: |
|
|
pretrained_params.append(param) |
|
|
else: |
|
|
new_params.append(param) |
|
|
|
|
|
optimizer = torch.optim.AdamW( |
|
|
[ |
|
|
{ |
|
|
"params": pretrained_params, |
|
|
"lr": self.optimizer_cfg.lr_monodepth, |
|
|
}, |
|
|
{"params": new_params, "lr": self.optimizer_cfg.lr}, |
|
|
], |
|
|
weight_decay=self.optimizer_cfg.weight_decay, |
|
|
) |
|
|
|
|
|
scheduler = torch.optim.lr_scheduler.OneCycleLR( |
|
|
optimizer, |
|
|
[self.optimizer_cfg.lr_monodepth, self.optimizer_cfg.lr], |
|
|
self.trainer.max_steps + 10, |
|
|
pct_start=0.01, |
|
|
cycle_momentum=False, |
|
|
anneal_strategy="cos", |
|
|
) |
|
|
|
|
|
else: |
|
|
optimizer = optim.AdamW( |
|
|
self.parameters(), |
|
|
lr=self.optimizer_cfg.lr, |
|
|
weight_decay=self.optimizer_cfg.weight_decay, |
|
|
) |
|
|
|
|
|
scheduler = torch.optim.lr_scheduler.OneCycleLR( |
|
|
optimizer, |
|
|
self.optimizer_cfg.lr, |
|
|
self.trainer.max_steps + 10, |
|
|
pct_start=0.01, |
|
|
cycle_momentum=False, |
|
|
anneal_strategy="cos", |
|
|
) |
|
|
|
|
|
return { |
|
|
"optimizer": optimizer, |
|
|
"lr_scheduler": { |
|
|
"scheduler": scheduler, |
|
|
"interval": "step", |
|
|
"frequency": 1, |
|
|
}, |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def snapshot_params(self): |
|
|
"""Return a snapshot dict name->cpu clone for all trainable params.""" |
|
|
return {n: p.detach().cpu().clone() for n, p in self.named_parameters() if p.requires_grad} |
|
|
|
|
|
def compare_params_snapshot(self, before_snap, tol=0.0, show_n=50): |
|
|
"""Compare saved snapshot with current params; report params with zero update.""" |
|
|
if before_snap is None: |
|
|
print("[CHECK] no before-snapshot provided") |
|
|
return |
|
|
zero_updates = [] |
|
|
small_updates = [] |
|
|
for name, p in self.named_parameters(): |
|
|
if not p.requires_grad: |
|
|
continue |
|
|
if name not in before_snap: |
|
|
continue |
|
|
before = before_snap[name] |
|
|
cur = p.detach().cpu() |
|
|
diff = (cur - before).abs().sum().item() |
|
|
if diff <= tol: |
|
|
zero_updates.append((name, diff)) |
|
|
else: |
|
|
if diff < 1e-12: |
|
|
small_updates.append((name, diff)) |
|
|
print(f"[CHECK] zero-updates={len(zero_updates)}, very-small-updates={len(small_updates)}") |
|
|
if zero_updates: |
|
|
print(" params with (nearly) zero update (first {}):".format(show_n)) |
|
|
for n, d in zero_updates[:show_n]: |
|
|
print(f" {n} diff_sum={d:.3e}") |
|
|
|
|
|
def list_params_with_no_grad_after_backward(self, tiny_thresh=1e-12, show_n=50): |
|
|
"""Call after backward: list params with grad None or extremely small norm.""" |
|
|
none_grad = [] |
|
|
very_small = [] |
|
|
nan_inf = [] |
|
|
for name, p in self.named_parameters(): |
|
|
if not p.requires_grad: |
|
|
continue |
|
|
g = p.grad |
|
|
if g is None: |
|
|
none_grad.append(name) |
|
|
else: |
|
|
if torch.isnan(g).any() or torch.isinf(g).any(): |
|
|
nan_inf.append(name) |
|
|
try: |
|
|
norm = float(g.norm().item()) |
|
|
except Exception: |
|
|
norm = 0.0 |
|
|
if norm < tiny_thresh: |
|
|
very_small.append((name, norm)) |
|
|
print(f"[CHECK] grads None: {len(none_grad)}, very_small (<{tiny_thresh}): {len(very_small)}, nan/inf: {len(nan_inf)}") |
|
|
if none_grad: |
|
|
print(" grads is None (first {}):".format(show_n)) |
|
|
for n in none_grad[:show_n]: |
|
|
print(" ", n) |
|
|
if very_small: |
|
|
print(" grads very small (first {}):".format(show_n)) |
|
|
for n, norm in very_small[:show_n]: |
|
|
print(f" {n} norm={norm:.3e}") |
|
|
if nan_inf: |
|
|
print(" grads have NaN/Inf (first {}):".format(show_n)) |
|
|
for n in nan_inf[:show_n]: |
|
|
print(" ", n) |
|
|
|
|
|
def list_frozen_params(self, show_n=50): |
|
|
frozen = [(n, p.shape) for n, p in self.named_parameters() if not p.requires_grad] |
|
|
print(f"[CHECK] frozen params count = {len(frozen)}") |
|
|
for name, shape in frozen[:show_n]: |
|
|
print(" ", name, shape) |
|
|
|
|
|
def check_optimizer_coverage(self): |
|
|
"""Check whether configured optimizer(s) include all trainable params""" |
|
|
try: |
|
|
opt_or_list = self.optimizers() |
|
|
except Exception as e: |
|
|
|
|
|
print("[CHECK] cannot access optimizer from here:", e) |
|
|
return |
|
|
opts = opt_or_list if isinstance(opt_or_list, (list, tuple)) else [opt_or_list] |
|
|
opt_param_ids = set(id(p) for o in opts for g in o.param_groups for p in g['params']) |
|
|
missing = [n for n, p in self.named_parameters() if p.requires_grad and id(p) not in opt_param_ids] |
|
|
print(f"[CHECK] optimizer missing params count = {len(missing)} (show up to 50):") |
|
|
for n in missing[:50]: |
|
|
print(" ", n) |
|
|
print(" optimizer lrs per opt:", [[g['lr'] for g in o.param_groups] for o in opts]) |
|
|
|
|
|
|
|
|
def on_after_backward(self): |
|
|
"""Called by Lightning after backward(): print grad stats (only if debug enabled).""" |
|
|
if not getattr(self, "_check_param_updates", False): |
|
|
return |
|
|
try: |
|
|
print(f"[on_after_backward] step={self.global_step}") |
|
|
|
|
|
self.list_params_with_no_grad_after_backward() |
|
|
except Exception as e: |
|
|
print("on_after_backward debug error:", e) |
|
|
|
|
|
    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        """Called after optimizer.step() in Lightning — compare params snapshot here."""
        # Debug-only path; enabled by setting self._check_param_updates = True.
        if not getattr(self, "_check_param_updates", False):
            return
        try:
            # Snapshot is presumably taken before the optimizer step elsewhere
            # (e.g. on_train_batch_start) — TODO confirm where it is set.
            before = getattr(self, "_before_params_snapshot", None)
            if before is None:
                print("[on_train_batch_end] no before snapshot found")
            else:
                # Report params that did not change at all across the step.
                self.compare_params_snapshot(before, tol=0.0)
            # Drop the snapshot so the next batch starts fresh.
            self._before_params_snapshot = None
            # Also verify every trainable param belongs to some optimizer group.
            self.check_optimizer_coverage()
        except Exception as e:
            # Debug helper must never break training.
            print("on_train_batch_end debug error:", e)