| """ |
| Training/inference wrapper for LiDAR-Perfect Depth. |
| |
| Mirrors `ppd.models.ppd_train.PixelPerfectDepth` but: |
| * substitutes DiT for LPDDiT (adds sparse-prompt path) |
| * simulates sparse-LiDAR observations on each training batch via |
| `sparse_simulator.simulate` |
| * adds the anchor-consistency loss to the velocity-MSE + grad loss |
| * `forward_test` runs the Kalman-in-loop sampler with posterior projection |
| |
| The DiT backbone (everything except prompt encoder + gate) can be optionally |
| frozen via `cfg.freeze_backbone` — paper §3.6 reports the full framework |
| trains fewer than 1% of parameters in this regime. |
| """ |
| from __future__ import annotations |
|
|
| import os |
| from omegaconf import DictConfig |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from ppd.utils.diffusion.timesteps import Timesteps |
| from ppd.utils.diffusion.schedule import LinearSchedule |
| from ppd.utils.diffusion.sampler import EulerSampler |
| from ppd.utils.diffusion.logitnormal import LogitNormalTrainingTimesteps |
|
|
| from ppd.models.depth_anything_v2.dpt import DepthAnythingV2 |
| from ppd.models.loss import multi_scale_grad_loss |
|
|
| from ppd.lpd.lpd_dit import LPDDiT |
| from ppd.lpd.sparse_simulator import simulate, random_pattern_choice |
| from ppd.lpd.losses import anchor_loss |
| from ppd.lpd.kalman_in_loop import kalman_in_loop_sample, KalmanInLoopConfig |
|
|
|
|
def _device() -> torch.device:
    """Return this process's CUDA device, indexed by ``LOCAL_RANK`` (default 0)."""
    rank = int(os.environ.get("LOCAL_RANK", "0"))
    return torch.device(f"cuda:{rank}")
|
|
|
|
class LiDARPerfectDepth(nn.Module):
    """LPD trainer/inferencer.

    Pairs an ``LPDDiT`` velocity-prediction score model with a frozen
    semantics encoder.  Training simulates sparse-LiDAR observations per
    batch and adds an anchor-consistency loss on top of the velocity MSE
    (+ optional multi-scale gradient loss); inference runs the
    Kalman-in-loop sampler with posterior projection onto the sparse
    observations.
    """

    def __init__(self, config: DictConfig):
        """Build the semantics encoder and DiT; optionally load PPD weights.

        Args:
            config: experiment config. Keys read here: ``semantics_model``,
                ``semantics_pth``, ``score_model.*``, ``prompt_scales``,
                ``prompt_hidden``, ``ppd_weights``, ``freeze_backbone``.
        """
        super().__init__()
        self.config = config
        self.configure_diffusion()

        # Frozen semantics encoder: MoGe2 when requested, otherwise a
        # DepthAnythingV2 ViT-L checkpoint loaded non-strictly.
        if config.semantics_model == "MoGe2":
            from ppd.moge.model.v2 import MoGeModel
            self.sem_encoder = MoGeModel.from_pretrained(config.semantics_pth)
        else:
            self.sem_encoder = DepthAnythingV2(
                encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024]
            )
            # NOTE(review): torch.load without weights_only — assumes the
            # checkpoint path is trusted; confirm before loading user data.
            self.sem_encoder.load_state_dict(
                torch.load(config.semantics_pth, map_location="cpu"), strict=False
            )
        self.sem_encoder = self.sem_encoder.to(_device()).eval()
        self.sem_encoder.requires_grad_(False)

        # DiT backbone + sparse-prompt path (prompt encoder and gate).
        self.dit = LPDDiT(
            in_channels=config.score_model.get("in_channels", 4),
            out_channels=config.score_model.get("out_channels", 1),
            hidden_size=config.score_model.get("hidden_size", 1024),
            depth=config.score_model.get("depth", 24),
            num_heads=config.score_model.get("num_heads", 16),
            patch_size=config.score_model.get("patch_size", 8),
            mlp_ratio=config.score_model.get("mlp_ratio", 4.0),
            prompt_scales=tuple(config.get("prompt_scales", (4, 8, 16, 32))),
            prompt_hidden=config.get("prompt_hidden", 128),
        )

        # Optionally warm-start the DiT from pretrained PPD weights.
        ppd_weights = config.get("ppd_weights", None)
        if ppd_weights and os.path.exists(ppd_weights):
            self._load_ppd_weights(ppd_weights)

        # Freeze everything except the sparse-prompt modules by default.
        if config.get("freeze_backbone", True):
            self.dit.freeze_backbone()

    def _load_ppd_weights(self, path: str) -> None:
        """Load PPD DiT weights from `path`, stripping known key prefixes.

        Keys are accepted under ``pipeline.dit.`` or ``dit.`` prefixes; if no
        key matches either prefix the state dict is used as-is. Loading is
        non-strict: missing sparse-prompt keys are expected (they are new in
        LPD) and are not reported.
        """
        sd = torch.load(path, map_location="cpu")
        if isinstance(sd, dict) and "state_dict" in sd:
            sd = sd["state_dict"]
        # Strip the first matching prefix per key.
        cleaned = {}
        for k, v in sd.items():
            for prefix in ("pipeline.dit.", "dit."):
                if k.startswith(prefix):
                    cleaned[k[len(prefix):]] = v
                    break
        if not cleaned:
            cleaned = sd
        missing, unexpected = self.dit.load_state_dict(cleaned, strict=False)

        # Only warn about missing keys outside the (expected) prompt modules.
        if any(
            not (k.startswith("sparse_prompt_encoder") or k.startswith("prompt_gate"))
            for k in missing
        ):
            print(f"[LPD] Missing keys when loading PPD weights: {missing[:5]}...")
        if unexpected:
            print(f"[LPD] Unexpected keys: {unexpected[:5]}...")

    def configure_diffusion(self) -> None:
        """Set up schedule, Euler sampler and logit-normal training timesteps."""
        self.schedule = LinearSchedule(T=1000)
        self.sampling_timesteps = Timesteps(
            T=self.schedule.T,
            steps=self.config.diffusion.timesteps.sampling.steps,
            device=_device(),
        )
        self.sampler = EulerSampler(
            schedule=self.schedule,
            timesteps=self.sampling_timesteps,
            prediction_type="velocity",
        )
        self.training_timesteps = LogitNormalTrainingTimesteps(
            T=self.schedule.T,
            loc=self.config.diffusion.timesteps.training.loc,
            scale=self.config.diffusion.timesteps.training.scale,
        )

    @torch.no_grad()
    def get_cond(self, img: torch.Tensor) -> torch.Tensor:
        """Zero-center the image (assumed in [0, 1] — TODO confirm) as DiT conditioning."""
        return img - 0.5

    @torch.no_grad()
    def semantics_prompt(self, image: torch.Tensor) -> torch.Tensor:
        """Extract the semantics prompt from the frozen encoder."""
        return self.sem_encoder.forward_semantics(image)

    @torch.no_grad()
    def get_gt(self, batch: dict) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns (latent, mask, log_min_used, log_max_used).

        The min/max are returned so we can normalize the simulated sparse
        observation into the same space the DiT predicts in.
        """
        depth = batch["depth"]
        mask = batch["mask"].bool()
        B = depth.shape[0]
        # Ignore far-field returns (>= 80 m) when computing the quantiles.
        clip_mask = mask & (depth < 80.0)
        log_depth = torch.log(depth + 1.0)
        # Per-sample robust range: 2%/98% quantiles of valid log-depths.
        min_vals, max_vals = [], []
        for i in range(B):
            i_d, i_m = log_depth[i], clip_mask[i]
            if i_m.sum() == 0:
                # No valid pixels — fall back to a [0, 1] range.
                min_vals.append(torch.tensor(0.0, device=depth.device))
                max_vals.append(torch.tensor(1.0, device=depth.device))
                continue
            vals = i_d[i_m]
            min_vals.append(torch.quantile(vals, 0.02))
            max_vals.append(torch.quantile(vals, 0.98))
        min_v = torch.stack(min_vals)[:, None, None, None]
        max_v = torch.stack(max_vals)[:, None, None, None]
        # Guard against degenerate (constant-depth) samples.
        invalid = (max_v - min_v) < 1e-6
        max_v = torch.where(invalid, min_v + 1e-6, max_v)
        norm = (log_depth - min_v) / (max_v - min_v)
        # Clamp, then shift: latent values end up in [-1.0, 0.5].
        norm = torch.clamp(norm, -0.5, 1.0) - 0.5
        return norm, mask, min_v, max_v

    @torch.no_grad()
    def normalize_sparse(
        self,
        sparse_depth: torch.Tensor,
        sparse_mask: torch.Tensor,
        log_min: torch.Tensor,
        log_max: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the same per-sample log-quantile normalization as `get_gt`.

        Unobserved pixels are zeroed via `sparse_mask`. Fix: clamp *before*
        the -0.5 shift, exactly as `get_gt` does — the previous
        shift-then-clamp order let observed values reach 1.0 while the GT
        latent tops out at 0.5, putting the anchor targets in a different
        range than the latent they constrain.
        """
        log_d = torch.log(sparse_depth + 1.0)
        norm = (log_d - log_min) / (log_max - log_min)
        # Clamp first, then shift — identical to the GT path in `get_gt`.
        norm = torch.clamp(norm, -0.5, 1.0) - 0.5
        return norm * sparse_mask.float()

    @torch.no_grad()
    def simulate_sparse_for_batch(
        self,
        depth: torch.Tensor,
        mask: torch.Tensor,
        cfg: DictConfig | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Wraps `simulate` with sensible defaults from config.

        Returns a (sparse_depth_metric, sparse_mask) pair. When the
        configured pattern is "auto", a random pattern is drawn per call.
        """
        scfg = (cfg or self.config).get("sparse", {})
        pattern = scfg.get("pattern", "auto")
        if pattern == "auto":
            pattern = random_pattern_choice()
        return simulate(
            depth, mask,
            pattern=pattern,
            density=scfg.get("density", 0.005),
            n_lines=scfg.get("n_lines", 64),
            line_density=scfg.get("line_density", 0.5),
            grid_stride=scfg.get("grid_stride", 32),
            min_points=scfg.get("min_points", 16),
            measurement_noise_std=scfg.get("measurement_noise_std", 0.0),
        )

    def forward_train(self, batch: dict) -> dict:
        """One training step: velocity MSE + anchor loss (+ grad loss).

        Expects ``batch`` with keys "image", "depth", "mask". Returns a dict
        with "loss", the de-shifted predicted "depth", the input "image" and
        the simulated "sparse_mask".
        """
        B = batch["image"].shape[0]
        cond = self.get_cond(batch["image"])
        latent, mask, log_min, log_max = self.get_gt(batch)
        semantics = self.semantics_prompt(batch["image"])

        # Simulate a sparse-LiDAR observation for this batch and bring it
        # into the same normalized space as the GT latent.
        sparse_depth_metric, sparse_mask = self.simulate_sparse_for_batch(
            batch["depth"], batch["mask"].bool()
        )
        sparse_depth_norm = self.normalize_sparse(
            sparse_depth_metric, sparse_mask, log_min, log_max
        )

        # Diffuse the latent to a sampled timestep and run the DiT.
        noises = torch.randn_like(latent)
        timesteps = self.training_timesteps.sample([B], device=_device())
        latent_noised = self.schedule.forward(latent, noises, timesteps)
        x = torch.cat([latent_noised, cond], dim=1)

        pred = self.dit(
            x=x,
            semantics=semantics,
            timestep=timesteps,
            sparse_depth=sparse_depth_norm,
            sparse_mask=sparse_mask,
        )

        # Compare in velocity space; masked mean over valid pixels.
        latent_pred, noises_pred = self.schedule.convert_from_pred(
            pred=pred, pred_type="velocity", x_t=latent_noised, t=timesteps
        )
        loss_input = self.schedule.convert_to_pred(
            x_0=latent_pred, x_T=noises_pred, t=timesteps, pred_type="velocity"
        )
        loss_target = self.schedule.convert_to_pred(
            x_0=latent, x_T=noises, t=timesteps, pred_type="velocity"
        )
        mse = F.mse_loss(loss_input, loss_target, reduction="none") * mask.float()
        mse = mse.sum() / (mask.float().sum() + 1e-6)
        loss = mse

        # Anchor-consistency loss against the sparse observation.
        lambda_anchor = float(self.config.get("lambda_anchor", 0.5))
        if lambda_anchor > 0:
            anc = anchor_loss(latent_pred, sparse_depth_norm, sparse_mask)
            loss = loss + lambda_anchor * anc

        # Multi-scale gradient loss only outside the pretraining phase.
        if not self.config.get("pretrain", False):
            grad = multi_scale_grad_loss(
                latent_pred.squeeze(1), latent.squeeze(1), mask.float().squeeze(1)
            )
            loss = loss + 0.2 * grad

        return {
            "loss": loss,
            # Undo the -0.5 shift for visualization/metrics.
            "depth": latent_pred + 0.5,
            "image": batch["image"],
            "sparse_mask": sparse_mask,
        }

    @torch.no_grad()
    def forward_test(self, batch: dict) -> dict:
        """Inference via Kalman-in-loop sampling with posterior projection.

        Sparse observations come from ``batch["sparse_depth"/"sparse_mask"]``
        when provided, else are simulated from dense GT, else are empty.
        Returns "depth" at the original resolution, the input "image" and
        the final Kalman variance.
        """
        # Resize to a fixed pixel budget, keeping dims multiples of 16.
        ori_h, ori_w = batch["image"].shape[-2:]
        target_area = 1024 * 768 if not self.config.get("pretrain", False) else 512 * 512
        scale = (target_area / (ori_w * ori_h)) ** 0.5
        new_h = max(16, int(round(ori_h * scale / 16)) * 16)
        new_w = max(16, int(round(ori_w * scale / 16)) * 16)
        image = F.interpolate(batch["image"], size=(new_h, new_w), mode="bilinear", align_corners=False)

        cond = self.get_cond(image)
        semantics = self.semantics_prompt(image)

        # Source the sparse observation (nearest interpolation keeps the
        # point-like structure of the sparse maps).
        if "sparse_depth" in batch and "sparse_mask" in batch:
            sparse_depth_metric = F.interpolate(
                batch["sparse_depth"], size=(new_h, new_w), mode="nearest"
            )
            sparse_mask = F.interpolate(
                batch["sparse_mask"].float(), size=(new_h, new_w), mode="nearest"
            ).bool()
        elif "depth" in batch and "mask" in batch:
            depth_resized = F.interpolate(
                batch["depth"], size=(new_h, new_w), mode="nearest"
            )
            mask_resized = F.interpolate(
                batch["mask"].float(), size=(new_h, new_w), mode="nearest"
            ).bool()
            sparse_depth_metric, sparse_mask = self.simulate_sparse_for_batch(
                depth_resized, mask_resized
            )
        else:
            # No observation available: run unconditioned on sparse input.
            sparse_depth_metric = torch.zeros(image.shape[0], 1, new_h, new_w, device=image.device)
            sparse_mask = torch.zeros_like(sparse_depth_metric, dtype=torch.bool)

        # Per-sample 2%/98% log-quantile range from the sparse points only
        # (no dense GT at test time), mirroring `get_gt`.
        log_d = torch.log(sparse_depth_metric.clamp_min(0.0) + 1.0)
        B = image.shape[0]
        log_min, log_max = [], []
        for i in range(B):
            m = sparse_mask[i].bool()
            if m.sum() == 0:
                log_min.append(torch.tensor(0.0, device=image.device))
                log_max.append(torch.tensor(1.0, device=image.device))
                continue
            vals = log_d[i][m]
            log_min.append(torch.quantile(vals, 0.02))
            log_max.append(torch.quantile(vals, 0.98))
        log_min_t = torch.stack(log_min)[:, None, None, None]
        log_max_t = torch.stack(log_max)[:, None, None, None]
        invalid = (log_max_t - log_min_t) < 1e-6
        log_max_t = torch.where(invalid, log_min_t + 1e-6, log_max_t)
        sparse_depth_norm = self.normalize_sparse(
            sparse_depth_metric, sparse_mask, log_min_t, log_max_t
        )

        x_T = torch.randn(B, 1, new_h, new_w, device=image.device)

        def predict_x0(x_tau: torch.Tensor, tau: torch.Tensor) -> torch.Tensor:
            # One DiT evaluation converted from velocity to x0 prediction.
            inp = torch.cat([x_tau, cond], dim=1)
            v_pred = self.dit(
                x=inp,
                semantics=semantics,
                timestep=tau,
                sparse_depth=sparse_depth_norm,
                sparse_mask=sparse_mask,
            )
            x0, _ = self.schedule.convert_from_pred(
                pred=v_pred, pred_type="velocity", x_t=x_tau, t=tau
            )
            return x0

        kfg = KalmanInLoopConfig(
            R_proj=self.config.get("R_proj", 0.1),
            proj_alpha=self.config.get("proj_alpha", 0.1),
            init_P=self.config.get("init_P", 1.0),
        )
        latent, P_final = kalman_in_loop_sample(
            dit_predict_x0=predict_x0,
            sampler=self.sampler,
            timesteps=list(self.sampling_timesteps),
            x_T=x_T,
            cond=cond,
            semantics_fn=lambda: semantics,
            sparse_depth=sparse_depth_norm,
            sparse_mask=sparse_mask,
            mu_temporal=batch.get("kalman_mu_prior"),
            P_temporal=batch.get("kalman_P_prior"),
            config=kfg,
        )

        # Undo the -0.5 shift and restore the original resolution.
        depth = latent + 0.5
        depth = F.interpolate(depth, size=(ori_h, ori_w), mode="nearest")
        return {
            "depth": depth,
            "image": batch["image"],
            "kalman_variance": P_final,
        }
|
|