""" Training/inference wrapper for LiDAR-Perfect Depth. Mirrors `ppd.models.ppd_train.PixelPerfectDepth` but: * substitutes DiT for LPDDiT (adds sparse-prompt path) * simulates sparse-LiDAR observations on each training batch via `sparse_simulator.simulate` * adds the anchor-consistency loss to the velocity-MSE + grad loss * `forward_test` runs the Kalman-in-loop sampler with posterior projection The DiT backbone (everything except prompt encoder + gate) can be optionally frozen via `cfg.freeze_backbone` — paper §3.6 reports the full framework trains fewer than 1% of parameters in this regime. """ from __future__ import annotations import os from omegaconf import DictConfig import torch import torch.nn as nn import torch.nn.functional as F from ppd.utils.diffusion.timesteps import Timesteps from ppd.utils.diffusion.schedule import LinearSchedule from ppd.utils.diffusion.sampler import EulerSampler from ppd.utils.diffusion.logitnormal import LogitNormalTrainingTimesteps from ppd.models.depth_anything_v2.dpt import DepthAnythingV2 from ppd.models.loss import multi_scale_grad_loss from ppd.lpd.lpd_dit import LPDDiT from ppd.lpd.sparse_simulator import simulate, random_pattern_choice from ppd.lpd.losses import anchor_loss from ppd.lpd.kalman_in_loop import kalman_in_loop_sample, KalmanInLoopConfig def _device() -> torch.device: return torch.device("cuda", int(os.environ.get("LOCAL_RANK", "0"))) class LiDARPerfectDepth(nn.Module): """LPD trainer/inferencer.""" def __init__(self, config: DictConfig): super().__init__() self.config = config self.configure_diffusion() # Semantics encoder (frozen, identical to PPD) if config.semantics_model == "MoGe2": from ppd.moge.model.v2 import MoGeModel self.sem_encoder = MoGeModel.from_pretrained(config.semantics_pth) else: self.sem_encoder = DepthAnythingV2( encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024] ) self.sem_encoder.load_state_dict( torch.load(config.semantics_pth, map_location="cpu"), strict=False ) self.sem_encoder = self.sem_encoder.to(_device()).eval() self.sem_encoder.requires_grad_(False) # LPD-DiT replaces the vanilla DiT self.dit = LPDDiT( in_channels=config.score_model.get("in_channels", 4), out_channels=config.score_model.get("out_channels", 1), hidden_size=config.score_model.get("hidden_size", 1024), depth=config.score_model.get("depth", 24), num_heads=config.score_model.get("num_heads", 16), patch_size=config.score_model.get("patch_size", 8), mlp_ratio=config.score_model.get("mlp_ratio", 4.0), prompt_scales=tuple(config.get("prompt_scales", (4, 8, 16, 32))), prompt_hidden=config.get("prompt_hidden", 128), ) # Optionally load PPD-pretrained DiT weights (everything except the new # sparse-prompt branch; load_state_dict with strict=False). ppd_weights = config.get("ppd_weights", None) if ppd_weights and os.path.exists(ppd_weights): self._load_ppd_weights(ppd_weights) if config.get("freeze_backbone", True): self.dit.freeze_backbone() # ------------------------------------------------------------------ setup def _load_ppd_weights(self, path: str) -> None: sd = torch.load(path, map_location="cpu") if isinstance(sd, dict) and "state_dict" in sd: sd = sd["state_dict"] # Strip any pipeline.dit. prefix or dit. 

    def configure_diffusion(self) -> None:
        self.schedule = LinearSchedule(T=1000)
        self.sampling_timesteps = Timesteps(
            T=self.schedule.T,
            steps=self.config.diffusion.timesteps.sampling.steps,
            device=_device(),
        )
        self.sampler = EulerSampler(
            schedule=self.schedule,
            timesteps=self.sampling_timesteps,
            prediction_type="velocity",
        )
        self.training_timesteps = LogitNormalTrainingTimesteps(
            T=self.schedule.T,
            loc=self.config.diffusion.timesteps.training.loc,
            scale=self.config.diffusion.timesteps.training.scale,
        )

    # ----------------------------------------------------------------- helpers
    @torch.no_grad()
    def get_cond(self, img: torch.Tensor) -> torch.Tensor:
        return img - 0.5

    @torch.no_grad()
    def semantics_prompt(self, image: torch.Tensor) -> torch.Tensor:
        return self.sem_encoder.forward_semantics(image)

    @torch.no_grad()
    def get_gt(
        self, batch: dict
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns (latent, mask, log_min_used, log_max_used).

        The min/max are returned so we can normalize the simulated sparse
        observation into the same space the DiT predicts in.
        """
        depth = batch["depth"]
        mask = batch["mask"].bool()
        B = depth.shape[0]

        clip_mask = mask & (depth < 80.0)
        log_depth = torch.log(depth + 1.0)

        min_vals, max_vals = [], []
        for i in range(B):
            i_d, i_m = log_depth[i], clip_mask[i]
            if i_m.sum() == 0:
                min_vals.append(torch.tensor(0.0, device=depth.device))
                max_vals.append(torch.tensor(1.0, device=depth.device))
                continue
            vals = i_d[i_m]
            min_vals.append(torch.quantile(vals, 0.02))
            max_vals.append(torch.quantile(vals, 0.98))

        min_v = torch.stack(min_vals)[:, None, None, None]
        max_v = torch.stack(max_vals)[:, None, None, None]
        invalid = (max_v - min_v) < 1e-6
        max_v = torch.where(invalid, min_v + 1e-6, max_v)

        norm = (log_depth - min_v) / (max_v - min_v)
        norm = torch.clamp(norm, -0.5, 1.0) - 0.5
        return norm, mask, min_v, max_v

    @torch.no_grad()
    def normalize_sparse(
        self,
        sparse_depth: torch.Tensor,
        sparse_mask: torch.Tensor,
        log_min: torch.Tensor,
        log_max: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the same per-sample log-quantile normalization as `get_gt`."""
        log_d = torch.log(sparse_depth + 1.0)
        norm = (log_d - log_min) / (log_max - log_min)
        # Clamp before the -0.5 shift, exactly as in `get_gt`, so the sparse
        # values land in the same range as the GT latent.
        norm = torch.clamp(norm, -0.5, 1.0) - 0.5
        return norm * sparse_mask.float()

    @torch.no_grad()
    def simulate_sparse_for_batch(
        self,
        depth: torch.Tensor,
        mask: torch.Tensor,
        cfg: DictConfig | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Wraps `simulate` with sensible defaults from config."""
        scfg = (cfg or self.config).get("sparse", {})
        pattern = scfg.get("pattern", "auto")
        if pattern == "auto":
            pattern = random_pattern_choice()
        return simulate(
            depth,
            mask,
            pattern=pattern,
            density=scfg.get("density", 0.005),
            n_lines=scfg.get("n_lines", 64),
            line_density=scfg.get("line_density", 0.5),
            grid_stride=scfg.get("grid_stride", 32),
            min_points=scfg.get("min_points", 16),
            measurement_noise_std=scfg.get("measurement_noise_std", 0.0),
        )
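
    # Illustrative `sparse` config block consumed by the wrapper above. The
    # keys and fallback values mirror the .get() calls exactly; the per-key
    # comments are inferred from the parameter names, not from `simulate`'s
    # implementation, and pattern names other than "auto" are not spelled out
    # here.
    #
    #   sparse:
    #     pattern: auto               # "auto" draws one via random_pattern_choice()
    #     density: 0.005              # point density for random patterns
    #     n_lines: 64                 # scanline count for line patterns
    #     line_density: 0.5           # fraction of points kept along each line
    #     grid_stride: 32             # spacing for grid patterns, in pixels
    #     min_points: 16              # floor on surviving measurements
    #     measurement_noise_std: 0.0  # additive noise on simulated ranges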

    # ------------------------------------------------------------------- train
    def forward_train(self, batch: dict) -> dict:
        B = batch["image"].shape[0]
        cond = self.get_cond(batch["image"])
        latent, mask, log_min, log_max = self.get_gt(batch)
        semantics = self.semantics_prompt(batch["image"])

        # Simulate sparse LiDAR from dense GT, then normalize into latent space
        sparse_depth_metric, sparse_mask = self.simulate_sparse_for_batch(
            batch["depth"], batch["mask"].bool()
        )
        sparse_depth_norm = self.normalize_sparse(
            sparse_depth_metric, sparse_mask, log_min, log_max
        )

        # Diffusion forward: noise the GT latent
        noises = torch.randn_like(latent)
        timesteps = self.training_timesteps.sample([B], device=_device())
        latent_noised = self.schedule.forward(latent, noises, timesteps)

        x = torch.cat([latent_noised, cond], dim=1)
        pred = self.dit(
            x=x,
            semantics=semantics,
            timestep=timesteps,
            sparse_depth=sparse_depth_norm,
            sparse_mask=sparse_mask,
        )

        latent_pred, noises_pred = self.schedule.convert_from_pred(
            pred=pred, pred_type="velocity", x_t=latent_noised, t=timesteps
        )
        loss_input = self.schedule.convert_to_pred(
            x_0=latent_pred, x_T=noises_pred, t=timesteps, pred_type="velocity"
        )
        loss_target = self.schedule.convert_to_pred(
            x_0=latent, x_T=noises, t=timesteps, pred_type="velocity"
        )

        mse = F.mse_loss(loss_input, loss_target, reduction="none") * mask.float()
        mse = mse.sum() / (mask.float().sum() + 1e-6)
        loss = mse

        # Anchor-consistency loss in normalized space
        lambda_anchor = float(self.config.get("lambda_anchor", 0.5))
        if lambda_anchor > 0:
            anc = anchor_loss(latent_pred, sparse_depth_norm, sparse_mask)
            loss = loss + lambda_anchor * anc

        # Multi-scale gradient loss (fine-tuning only, mirrors PPD)
        if not self.config.get("pretrain", False):
            grad = multi_scale_grad_loss(
                latent_pred.squeeze(1), latent.squeeze(1), mask.float().squeeze(1)
            )
            loss = loss + 0.2 * grad

        return {
            "loss": loss,
            "depth": latent_pred + 0.5,
            "image": batch["image"],
            "sparse_mask": sparse_mask,
        }
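
    # Sketch of a training step driving `forward_train` (illustrative; the
    # optimizer choice, learning rate, and batch construction are assumptions,
    # not part of this file):
    #
    #   model = LiDARPerfectDepth(cfg).to(_device())
    #   params = [p for p in model.parameters() if p.requires_grad]
    #   opt = torch.optim.AdamW(params, lr=1e-4)
    #   for batch in loader:          # batch: {"image", "depth", "mask"}
    #       out = model.forward_train(batch)
    #       out["loss"].backward()
    #       opt.step()
    #       opt.zero_grad()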

    # ------------------------------------------------------------------ infer
    @torch.no_grad()
    def forward_test(self, batch: dict) -> dict:
        ori_h, ori_w = batch["image"].shape[-2:]
        target_area = 1024 * 768 if not self.config.get("pretrain", False) else 512 * 512
        scale = (target_area / (ori_w * ori_h)) ** 0.5
        new_h = max(16, int(round(ori_h * scale / 16)) * 16)
        new_w = max(16, int(round(ori_w * scale / 16)) * 16)

        image = F.interpolate(
            batch["image"], size=(new_h, new_w), mode="bilinear", align_corners=False
        )
        cond = self.get_cond(image)
        semantics = self.semantics_prompt(image)

        # Sparse observations: from the batch if provided, else simulated from
        # dense GT
        if "sparse_depth" in batch and "sparse_mask" in batch:
            sparse_depth_metric = F.interpolate(
                batch["sparse_depth"], size=(new_h, new_w), mode="nearest"
            )
            sparse_mask = F.interpolate(
                batch["sparse_mask"].float(), size=(new_h, new_w), mode="nearest"
            ).bool()
        elif "depth" in batch and "mask" in batch:
            depth_resized = F.interpolate(
                batch["depth"], size=(new_h, new_w), mode="nearest"
            )
            mask_resized = F.interpolate(
                batch["mask"].float(), size=(new_h, new_w), mode="nearest"
            ).bool()
            sparse_depth_metric, sparse_mask = self.simulate_sparse_for_batch(
                depth_resized, mask_resized
            )
        else:
            sparse_depth_metric = torch.zeros(
                image.shape[0], 1, new_h, new_w, device=image.device
            )
            sparse_mask = torch.zeros_like(sparse_depth_metric, dtype=torch.bool)

        # Normalize sparse: use a coarse min/max from the sparse observations
        # themselves (no GT depth available at inference).
        log_d = torch.log(sparse_depth_metric.clamp_min(0.0) + 1.0)
        B = image.shape[0]
        log_min, log_max = [], []
        for i in range(B):
            m = sparse_mask[i].bool()
            if m.sum() == 0:
                log_min.append(torch.tensor(0.0, device=image.device))
                log_max.append(torch.tensor(1.0, device=image.device))
                continue
            vals = log_d[i][m]
            log_min.append(torch.quantile(vals, 0.02))
            log_max.append(torch.quantile(vals, 0.98))
        log_min_t = torch.stack(log_min)[:, None, None, None]
        log_max_t = torch.stack(log_max)[:, None, None, None]
        invalid = (log_max_t - log_min_t) < 1e-6
        log_max_t = torch.where(invalid, log_min_t + 1e-6, log_max_t)
        sparse_depth_norm = self.normalize_sparse(
            sparse_depth_metric, sparse_mask, log_min_t, log_max_t
        )

        # Kalman-in-loop sampling
        x_T = torch.randn(B, 1, new_h, new_w, device=image.device)

        def predict_x0(x_tau: torch.Tensor, tau: torch.Tensor) -> torch.Tensor:
            inp = torch.cat([x_tau, cond], dim=1)
            v_pred = self.dit(
                x=inp,
                semantics=semantics,
                timestep=tau,
                sparse_depth=sparse_depth_norm,
                sparse_mask=sparse_mask,
            )
            x0, _ = self.schedule.convert_from_pred(
                pred=v_pred, pred_type="velocity", x_t=x_tau, t=tau
            )
            return x0

        kfg = KalmanInLoopConfig(
            R_proj=self.config.get("R_proj", 0.1),
            proj_alpha=self.config.get("proj_alpha", 0.1),
            init_P=self.config.get("init_P", 1.0),
        )
        latent, P_final = kalman_in_loop_sample(
            dit_predict_x0=predict_x0,
            sampler=self.sampler,
            timesteps=list(self.sampling_timesteps),
            x_T=x_T,
            cond=cond,
            semantics_fn=lambda: semantics,
            sparse_depth=sparse_depth_norm,
            sparse_mask=sparse_mask,
            mu_temporal=batch.get("kalman_mu_prior"),
            P_temporal=batch.get("kalman_P_prior"),
            config=kfg,
        )

        depth = latent + 0.5
        depth = F.interpolate(depth, size=(ori_h, ori_w), mode="nearest")
        return {
            "depth": depth,
            "image": batch["image"],
            "kalman_variance": P_final,
        }
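

# Minimal inference sketch (illustrative; the config path, batch keys, and
# tensor shapes below are assumptions, not part of this module):
#
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.load("configs/lpd.yaml")   # hypothetical config path
#   model = LiDARPerfectDepth(cfg).eval()
#   batch = {
#       "image": image,                # (B, 3, H, W) float
#       "sparse_depth": sparse_depth,  # (B, 1, H, W) metric depth
#       "sparse_mask": sparse_mask,    # (B, 1, H, W) bool
#   }
#   out = model.forward_test(batch)
#   depth = out["depth"]               # normalized depth at input resolution
#   var = out["kalman_variance"]       # P_final from the Kalman-in-loop sampler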