"""
Training/inference wrapper for LiDAR-Perfect Depth.
Mirrors `ppd.models.ppd_train.PixelPerfectDepth` but:
* substitutes DiT for LPDDiT (adds sparse-prompt path)
* simulates sparse-LiDAR observations on each training batch via
`sparse_simulator.simulate`
* adds the anchor-consistency loss to the velocity-MSE + grad loss
* `forward_test` runs the Kalman-in-loop sampler with posterior projection
The DiT backbone (everything except prompt encoder + gate) can be optionally
frozen via `cfg.freeze_backbone` — paper §3.6 reports the full framework
trains fewer than 1% of parameters in this regime.
"""
from __future__ import annotations
import os
from omegaconf import DictConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
from ppd.utils.diffusion.timesteps import Timesteps
from ppd.utils.diffusion.schedule import LinearSchedule
from ppd.utils.diffusion.sampler import EulerSampler
from ppd.utils.diffusion.logitnormal import LogitNormalTrainingTimesteps
from ppd.models.depth_anything_v2.dpt import DepthAnythingV2
from ppd.models.loss import multi_scale_grad_loss
from ppd.lpd.lpd_dit import LPDDiT
from ppd.lpd.sparse_simulator import simulate, random_pattern_choice
from ppd.lpd.losses import anchor_loss
from ppd.lpd.kalman_in_loop import kalman_in_loop_sample, KalmanInLoopConfig
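
# Training assumes one CUDA device per process (torchrun-style DDP):
# LOCAL_RANK selects the device and defaults to GPU 0 for single-process runs.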
def _device() -> torch.device:
return torch.device("cuda", int(os.environ.get("LOCAL_RANK", "0")))
class LiDARPerfectDepth(nn.Module):
"""LPD trainer/inferencer."""
def __init__(self, config: DictConfig):
super().__init__()
self.config = config
self.configure_diffusion()
# Semantics encoder (frozen, identical to PPD)
if config.semantics_model == "MoGe2":
from ppd.moge.model.v2 import MoGeModel
self.sem_encoder = MoGeModel.from_pretrained(config.semantics_pth)
else:
self.sem_encoder = DepthAnythingV2(
encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024]
)
self.sem_encoder.load_state_dict(
torch.load(config.semantics_pth, map_location="cpu"), strict=False
)
self.sem_encoder = self.sem_encoder.to(_device()).eval()
self.sem_encoder.requires_grad_(False)
# LPD-DiT replaces the vanilla DiT
self.dit = LPDDiT(
in_channels=config.score_model.get("in_channels", 4),
out_channels=config.score_model.get("out_channels", 1),
hidden_size=config.score_model.get("hidden_size", 1024),
depth=config.score_model.get("depth", 24),
num_heads=config.score_model.get("num_heads", 16),
patch_size=config.score_model.get("patch_size", 8),
mlp_ratio=config.score_model.get("mlp_ratio", 4.0),
prompt_scales=tuple(config.get("prompt_scales", (4, 8, 16, 32))),
prompt_hidden=config.get("prompt_hidden", 128),
)
# Optionally load PPD-pretrained DiT weights (everything except the new
# sparse-prompt branch; load_state_dict with strict=False).
ppd_weights = config.get("ppd_weights", None)
if ppd_weights and os.path.exists(ppd_weights):
self._load_ppd_weights(ppd_weights)
if config.get("freeze_backbone", True):
self.dit.freeze_backbone()
# ------------------------------------------------------------------ setup
def _load_ppd_weights(self, path: str) -> None:
sd = torch.load(path, map_location="cpu")
if isinstance(sd, dict) and "state_dict" in sd:
sd = sd["state_dict"]
# Strip any pipeline.dit. prefix or dit. prefix
cleaned = {}
for k, v in sd.items():
for prefix in ("pipeline.dit.", "dit."):
if k.startswith(prefix):
cleaned[k[len(prefix):]] = v
break
if not cleaned:
cleaned = sd
missing, unexpected = self.dit.load_state_dict(cleaned, strict=False)
# Expect 'sparse_prompt_encoder.*' and 'prompt_gate.*' to be missing
# (they're new modules), and nothing to be unexpected.
if any(
not (k.startswith("sparse_prompt_encoder") or k.startswith("prompt_gate"))
for k in missing
):
print(f"[LPD] Missing keys when loading PPD weights: {missing[:5]}...")
if unexpected:
print(f"[LPD] Unexpected keys: {unexpected[:5]}...")
def configure_diffusion(self) -> None:
self.schedule = LinearSchedule(T=1000)
self.sampling_timesteps = Timesteps(
T=self.schedule.T,
steps=self.config.diffusion.timesteps.sampling.steps,
device=_device(),
)
self.sampler = EulerSampler(
schedule=self.schedule,
timesteps=self.sampling_timesteps,
prediction_type="velocity",
)
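        # Logit-normal timestep sampling biases training toward mid-range noise
        # levels (the rectified-flow / SD3 recipe); loc and scale come from the
        # diffusion config.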
self.training_timesteps = LogitNormalTrainingTimesteps(
T=self.schedule.T,
loc=self.config.diffusion.timesteps.training.loc,
scale=self.config.diffusion.timesteps.training.scale,
)
# ----------------------------------------------------------------- helpers
@torch.no_grad()
def get_cond(self, img: torch.Tensor) -> torch.Tensor:
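        # Center RGB to roughly [-0.5, 0.5] (assuming inputs in [0, 1]) so the
        # conditioning channels share the latent's value range.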
return img - 0.5
@torch.no_grad()
def semantics_prompt(self, image: torch.Tensor) -> torch.Tensor:
return self.sem_encoder.forward_semantics(image)
@torch.no_grad()
def get_gt(self, batch: dict) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Returns (latent, mask, log_min_used, log_max_used).
The min/max are returned so we can normalize the simulated sparse
observation into the same space the DiT predicts in.
"""
depth = batch["depth"]
mask = batch["mask"].bool()
B = depth.shape[0]
clip_mask = mask & (depth < 80.0)
log_depth = torch.log(depth + 1.0)
min_vals, max_vals = [], []
for i in range(B):
i_d, i_m = log_depth[i], clip_mask[i]
if i_m.sum() == 0:
min_vals.append(torch.tensor(0.0, device=depth.device))
max_vals.append(torch.tensor(1.0, device=depth.device))
continue
vals = i_d[i_m]
min_vals.append(torch.quantile(vals, 0.02))
max_vals.append(torch.quantile(vals, 0.98))
min_v = torch.stack(min_vals)[:, None, None, None]
max_v = torch.stack(max_vals)[:, None, None, None]
invalid = (max_v - min_v) < 1e-6
max_v = torch.where(invalid, min_v + 1e-6, max_v)
norm = (log_depth - min_v) / (max_v - min_v)
norm = torch.clamp(norm, -0.5, 1.0) - 0.5
return norm, mask, min_v, max_v
@torch.no_grad()
def normalize_sparse(
self,
sparse_depth: torch.Tensor,
sparse_mask: torch.Tensor,
log_min: torch.Tensor,
log_max: torch.Tensor,
) -> torch.Tensor:
"""Apply the same per-sample log-quantile normalization as `get_gt`."""
        log_d = torch.log(sparse_depth + 1.0)
        # Clamp before the -0.5 shift (same order as `get_gt`) so the sparse
        # values land in exactly the latent space the DiT predicts in.
        norm = torch.clamp((log_d - log_min) / (log_max - log_min), -0.5, 1.0) - 0.5
        return norm * sparse_mask.float()
@torch.no_grad()
def simulate_sparse_for_batch(
self,
depth: torch.Tensor,
mask: torch.Tensor,
cfg: DictConfig | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Wraps `simulate` with sensible defaults from config."""
scfg = (cfg or self.config).get("sparse", {})
pattern = scfg.get("pattern", "auto")
if pattern == "auto":
pattern = random_pattern_choice()
return simulate(
depth, mask,
pattern=pattern,
density=scfg.get("density", 0.005),
n_lines=scfg.get("n_lines", 64),
line_density=scfg.get("line_density", 0.5),
grid_stride=scfg.get("grid_stride", 32),
min_points=scfg.get("min_points", 16),
measurement_noise_std=scfg.get("measurement_noise_std", 0.0),
)
# ------------------------------------------------------------------- train
def forward_train(self, batch: dict) -> dict:
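        """One training step: velocity-MSE + anchor-consistency (+ grad loss).

        Expects `batch` with dense `image`, `depth`, and `mask`; the sparse
        LiDAR input is simulated on the fly from the dense GT.
        """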
B = batch["image"].shape[0]
cond = self.get_cond(batch["image"])
latent, mask, log_min, log_max = self.get_gt(batch)
semantics = self.semantics_prompt(batch["image"])
# Simulate sparse-LiDAR from dense GT, then normalize into latent space
sparse_depth_metric, sparse_mask = self.simulate_sparse_for_batch(
batch["depth"], batch["mask"].bool()
)
sparse_depth_norm = self.normalize_sparse(
sparse_depth_metric, sparse_mask, log_min, log_max
)
# Diffusion forward: noise the GT latent
noises = torch.randn_like(latent)
timesteps = self.training_timesteps.sample([B], device=_device())
latent_noised = self.schedule.forward(latent, noises, timesteps)
x = torch.cat([latent_noised, cond], dim=1)
pred = self.dit(
x=x,
semantics=semantics,
timestep=timesteps,
sparse_depth=sparse_depth_norm,
sparse_mask=sparse_mask,
)
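        # Round-trip the velocity prediction through (x0, noise) and back so
        # the MSE compares prediction and target in velocity space, mirroring
        # the PPD loss formulation.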
latent_pred, noises_pred = self.schedule.convert_from_pred(
pred=pred, pred_type="velocity", x_t=latent_noised, t=timesteps
)
loss_input = self.schedule.convert_to_pred(
x_0=latent_pred, x_T=noises_pred, t=timesteps, pred_type="velocity"
)
loss_target = self.schedule.convert_to_pred(
x_0=latent, x_T=noises, t=timesteps, pred_type="velocity"
)
mse = F.mse_loss(loss_input, loss_target, reduction="none") * mask.float()
mse = mse.sum() / (mask.float().sum() + 1e-6)
loss = mse
# Anchor consistency loss in normalized space
lambda_anchor = float(self.config.get("lambda_anchor", 0.5))
if lambda_anchor > 0:
anc = anchor_loss(latent_pred, sparse_depth_norm, sparse_mask)
loss = loss + lambda_anchor * anc
# Multi-scale gradient loss (fine-tuning only, mirrors PPD)
if not self.config.get("pretrain", False):
grad = multi_scale_grad_loss(
latent_pred.squeeze(1), latent.squeeze(1), mask.float().squeeze(1)
)
loss = loss + 0.2 * grad
return {
"loss": loss,
"depth": latent_pred + 0.5,
"image": batch["image"],
"sparse_mask": sparse_mask,
}
# ------------------------------------------------------------------ infer
@torch.no_grad()
def forward_test(self, batch: dict) -> dict:
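        """Inference: resize, build the sparse prompt, then run the
        Kalman-in-loop sampler with posterior projection.

        If the batch carries no `sparse_depth`/`sparse_mask`, sparse
        observations are simulated from dense GT when available, falling back
        to an empty prompt otherwise.
        """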
ori_h, ori_w = batch["image"].shape[-2:]
target_area = 1024 * 768 if not self.config.get("pretrain", False) else 512 * 512
scale = (target_area / (ori_w * ori_h)) ** 0.5
new_h = max(16, int(round(ori_h * scale / 16)) * 16)
new_w = max(16, int(round(ori_w * scale / 16)) * 16)
image = F.interpolate(batch["image"], size=(new_h, new_w), mode="bilinear", align_corners=False)
cond = self.get_cond(image)
semantics = self.semantics_prompt(image)
# Sparse observations: from batch if provided, else simulate from dense GT
if "sparse_depth" in batch and "sparse_mask" in batch:
sparse_depth_metric = F.interpolate(
batch["sparse_depth"], size=(new_h, new_w), mode="nearest"
)
sparse_mask = F.interpolate(
batch["sparse_mask"].float(), size=(new_h, new_w), mode="nearest"
).bool()
elif "depth" in batch and "mask" in batch:
depth_resized = F.interpolate(
batch["depth"], size=(new_h, new_w), mode="nearest"
)
mask_resized = F.interpolate(
batch["mask"].float(), size=(new_h, new_w), mode="nearest"
).bool()
sparse_depth_metric, sparse_mask = self.simulate_sparse_for_batch(
depth_resized, mask_resized
)
else:
sparse_depth_metric = torch.zeros(image.shape[0], 1, new_h, new_w, device=image.device)
sparse_mask = torch.zeros_like(sparse_depth_metric, dtype=torch.bool)
# Normalize sparse: use a coarse min/max from the sparse observations
# themselves (no GT depth available at inference).
log_d = torch.log(sparse_depth_metric.clamp_min(0.0) + 1.0)
B = image.shape[0]
log_min, log_max = [], []
for i in range(B):
m = sparse_mask[i].bool()
if m.sum() == 0:
log_min.append(torch.tensor(0.0, device=image.device))
log_max.append(torch.tensor(1.0, device=image.device))
continue
vals = log_d[i][m]
log_min.append(torch.quantile(vals, 0.02))
log_max.append(torch.quantile(vals, 0.98))
log_min_t = torch.stack(log_min)[:, None, None, None]
log_max_t = torch.stack(log_max)[:, None, None, None]
invalid = (log_max_t - log_min_t) < 1e-6
log_max_t = torch.where(invalid, log_min_t + 1e-6, log_max_t)
sparse_depth_norm = self.normalize_sparse(
sparse_depth_metric, sparse_mask, log_min_t, log_max_t
)
# Kalman-in-loop sampling
x_T = torch.randn(B, 1, new_h, new_w, device=image.device)
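
        # The sampler calls this closure at every step: it runs the LPD-DiT
        # once and converts the velocity prediction into an x0 estimate, which
        # `kalman_in_loop_sample` then fuses with the sparse observations via
        # the posterior projection.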
def predict_x0(x_tau: torch.Tensor, tau: torch.Tensor) -> torch.Tensor:
inp = torch.cat([x_tau, cond], dim=1)
v_pred = self.dit(
x=inp,
semantics=semantics,
timestep=tau,
sparse_depth=sparse_depth_norm,
sparse_mask=sparse_mask,
)
x0, _ = self.schedule.convert_from_pred(
pred=v_pred, pred_type="velocity", x_t=x_tau, t=tau
)
return x0
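
        # Posterior-projection hyper-parameters. The names suggest R_proj is
        # the observation-noise variance, proj_alpha the projection strength,
        # and init_P the initial state variance; see `KalmanInLoopConfig` for
        # the authoritative semantics.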
kfg = KalmanInLoopConfig(
R_proj=self.config.get("R_proj", 0.1),
proj_alpha=self.config.get("proj_alpha", 0.1),
init_P=self.config.get("init_P", 1.0),
)
latent, P_final = kalman_in_loop_sample(
dit_predict_x0=predict_x0,
sampler=self.sampler,
timesteps=list(self.sampling_timesteps),
x_T=x_T,
cond=cond,
semantics_fn=lambda: semantics,
sparse_depth=sparse_depth_norm,
sparse_mask=sparse_mask,
mu_temporal=batch.get("kalman_mu_prior"),
P_temporal=batch.get("kalman_P_prior"),
config=kfg,
)
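        # `latent` lives in the normalized log-depth space of `get_gt`, so
        # `latent + 0.5` is relative depth in roughly [0, 1]; recovering metric
        # depth would require inverting the per-sample log_min/log_max affine.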
depth = latent + 0.5
depth = F.interpolate(depth, size=(ori_h, ori_w), mode="nearest")
return {
"depth": depth,
"image": batch["image"],
"kalman_variance": P_final,
}