""" Posterior score projection — Eq. 5 in paper. Given the current denoising latent x_τ at noise level σ_τ, the projection applies the LiDAR likelihood and the Kalman temporal prior as additive score corrections: x_τ ← x_τ + η_τ * [ -M ⊙ (x_τ - y) / R - (x_τ - μ_prior) / P_prior ] with η_τ = α · σ_τ². RePaint corresponds to keeping only the LiDAR term with η_τ = 1 (hard projection); DPS keeps both terms but evaluates the LiDAR gradient through the score model. Our generalization uses both terms with a schedule-adapted step size. All inputs must already live in the same normalized depth space the DiT predicts in (PPD's [-0.5, 0.5] log-quantile space). """ from __future__ import annotations import torch @torch.no_grad() def posterior_project( x_t: torch.Tensor, sigma_t: torch.Tensor, *, sparse_depth: torch.Tensor, sparse_mask: torch.Tensor, R: float, mu_prior: torch.Tensor | None, P_prior: torch.Tensor | None, alpha: float = 1.0, ) -> torch.Tensor: """One projection step. Parameters ---------- x_t : (B,1,H,W) — current latent (in normalized log-depth space). sigma_t : scalar tensor — current noise level (paper uses τ/T). sparse_depth, sparse_mask : (B,1,H,W) — observations (already normalized). R : measurement noise variance (in same units as x). mu_prior : (B,1,H,W) or None — Kalman prior mean. P_prior : (B,1,H,W) or None — Kalman prior variance. alpha : scaling on the schedule-adapted step size. """ eta = alpha * (sigma_t ** 2) # LiDAR term: only at observed pixels lidar_term = -sparse_mask.float() * (x_t - sparse_depth) / R # Kalman term: only when we have a prior (otherwise no temporal info) if mu_prior is not None and P_prior is not None: kalman_term = -(x_t - mu_prior) / P_prior.clamp_min(1e-6) else: kalman_term = torch.zeros_like(x_t) return x_t + eta * (lidar_term + kalman_term)