"""
Posterior score projection — Eq. 5 in paper.

Given the current denoising latent x_τ at noise level σ_τ, the projection
applies the LiDAR likelihood and the Kalman temporal prior as additive score
corrections:

    x_τ ← x_τ + η_τ * [ -M ⊙ (x_τ - y) / R   -   (x_τ - μ_prior) / P_prior ]

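with η_τ = α · σ_τ², where y is the sparse LiDAR depth, M its observation
mask, R the LiDAR noise variance, and μ_prior, P_prior the Kalman prior mean
and variance. RePaint corresponds to keeping only the LiDAR term with
η_τ = R (η_τ = 1 when R = 1), so observed pixels are overwritten by y (hard
projection); DPS keeps both terms but evaluates the LiDAR gradient through the
score model. Our generalization uses both terms with a schedule-adapted step
size.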

All inputs must already live in the same normalized depth space the DiT
predicts in (PPD's [-0.5, 0.5] log-quantile space).
"""
from __future__ import annotations

import torch

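
# Illustrative scalar sketch of Eq. 5 (hypothetical helper added for
# exposition; posterior_project does not call it). It applies the same update
# to a single pixel with plain floats, so the special cases described in the
# module docstring can be checked by hand. Assumes all values are already in
# the normalized depth space; y=None stands for an unobserved pixel (M = 0).
from typing import Optional


def _scalar_posterior_step(
    x: float,
    sigma: float,
    *,
    y: Optional[float],
    R: float,
    mu: Optional[float] = None,
    P: Optional[float] = None,
    alpha: float = 1.0,
) -> float:
    """Scalar form of Eq. 5 for one pixel (illustration only)."""
    eta = alpha * sigma**2
    # LiDAR term: present only where the pixel has a LiDAR return.
    lidar = -(x - y) / R if y is not None else 0.0
    # Kalman term: present only when both prior mean and variance are given;
    # the variance is clamped the same way posterior_project clamps P_prior.
    if mu is not None and P is not None:
        kalman = -(x - mu) / max(P, 1e-6)
    else:
        kalman = 0.0
    return x + eta * (lidar + kalman)


# Usage (values chosen to be exact in binary floating point):
#   With no prior and eta = alpha * sigma**2 equal to R, the step lands exactly
#   on y, the hard-projection case the docstring attributes to RePaint:
#     _scalar_posterior_step(0.5, 0.5, y=-0.25, R=0.25)  -> -0.25
#   An unobserved pixel with no prior is left unchanged:
#     _scalar_posterior_step(0.5, 0.5, y=None, R=0.25)   -> 0.5
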

@torch.no_grad()
def posterior_project(
    x_t: torch.Tensor,
    sigma_t: torch.Tensor,
    *,
    sparse_depth: torch.Tensor,
    sparse_mask: torch.Tensor,
    R: float,
    mu_prior: torch.Tensor | None,
    P_prior: torch.Tensor | None,
    alpha: float = 1.0,
) -> torch.Tensor:
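    """Apply one posterior projection step (Eq. 5).

    Parameters
    ----------
    x_t        : (B,1,H,W) — current latent (in normalized log-depth space).
    sigma_t    : scalar tensor — current noise level (paper uses τ/T).
    sparse_depth, sparse_mask : (B,1,H,W) — observations (already normalized);
                 the mask is nonzero (True/1) at pixels with a LiDAR return.
    R          : measurement noise variance (in squared units of x); must be > 0.
    mu_prior   : (B,1,H,W) or None — Kalman prior mean.
    P_prior    : (B,1,H,W) or None — Kalman prior variance, clamped to >= 1e-6.
                 The temporal term is applied only if both prior arguments are given.
    alpha      : scaling on the schedule-adapted step size η = α · σ_t².

    Returns
    -------
    torch.Tensor
        The corrected latent, with the same shape as ``x_t``.
    """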
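    if R <= 0:
        raise ValueError(f"R must be a positive variance, got {R}")
    # Schedule-adapted step size: proportional to sigma_t**2, so the
    # corrections fade as denoising approaches sigma_t = 0.
    eta = alpha * (sigma_t ** 2)
    # LiDAR term: only at observed pixels. Cast the mask to x_t's dtype so a
    # half-precision latent is not silently promoted to float32.
    lidar_term = -sparse_mask.to(x_t.dtype) * (x_t - sparse_depth) / R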

    # Kalman term: only when we have a prior (otherwise no temporal info)
    if mu_prior is not None and P_prior is not None:
        kalman_term = -(x_t - mu_prior) / P_prior.clamp_min(1e-6)
    else:
        kalman_term = torch.zeros_like(x_t)

    return x_t + eta * (lidar_term + kalman_term)
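

# Illustrative check of Eq. 5 as an optimization step (hypothetical helpers
# added for exposition; not used by posterior_project). At an observed pixel
# with a prior, the bracket in Eq. 5 is the negative gradient of
#     E(x) = (x - y)**2 / (2 R) + (x - mu)**2 / (2 P),
# the negative log of a Gaussian likelihood times a Gaussian prior. Its
# minimizer is the precision-weighted (Kalman) fusion of y and mu. Repeating
# the update with a fixed eta converges there only when
# 0 < eta * (1/R + 1/P) < 2; larger steps overshoot and diverge. The sampler
# itself applies a single step per denoising iteration, with eta shrinking as
# sigma_t**2, so this loop only illustrates the gradient-step view.


def _fused_mean(y: float, R: float, mu: float, P: float) -> float:
    """Precision-weighted mean of two Gaussian estimates of the same depth."""
    return (y / R + mu / P) / (1.0 / R + 1.0 / P)


def _iterate_scalar_step(
    x: float, eta: float, y: float, R: float, mu: float, P: float, n: int
) -> float:
    """Apply the scalar bracket of Eq. 5 n times with a fixed step size eta."""
    for _ in range(n):
        x = x + eta * (-(x - y) / R - (x - mu) / P)
    return x


# Usage: with y = 0.2, R = 0.01, mu = 0.0, P = 0.04 the precisions are 100 and
# 25, so the fused mean is 0.2 * 100 / 125 = 0.16. eta = 0.004 gives
# eta * 125 = 0.5, which converges; eta = 0.02 gives 2.5, which diverges.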