# Source: LiDAR-Perfect-Depth/code/ppd/lpd/kalman_in_loop.py
# Uploaded by chenming-wu, commit 436b829 (verified).
"""
Kalman-in-the-loop denoising sampler — Algorithm 1 in paper.
Given a DiT that predicts x_0 (or velocity convertible to x_0) from x_τ at each
denoising step, run a per-pixel Kalman state estimator across sampling steps.
The state accumulates evidence: early steps trust the temporal prior (large σ²
≈ large measurement variance), late steps shift toward the DiT prediction.
Variance is monotonically decreasing through denoising. The final posterior
variance is returned alongside the depth — `temporal_kalman.py` consumes it as
the next frame's prior.
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
from ppd.utils.diffusion.sampler import EulerSampler
from ppd.lpd.posterior_projection import posterior_project
@dataclass()
class KalmanInLoopConfig:
    """Hyper-parameters for the Kalman-in-the-loop sampler."""

    # Variance R of the LiDAR measurements used by the posterior-projection
    # step (paper §4.4).
    R_proj: float = 0.1
    # Step-size scale α applied by the posterior projection.
    proj_alpha: float = 0.1
    # Lower bound on the per-step measurement noise. The sampler clamps the
    # standard deviation σ_τ with this value, so the effective measurement
    # *variance* floor is its square.
    measurement_floor: float = 1e-3
    # Prior variance used when no temporal prior is supplied.
    init_P: float = 1.0
@torch.no_grad()
def kalman_in_loop_sample(
    *,
    dit_predict_x0,
    sampler: EulerSampler,
    timesteps: torch.Tensor,
    x_T: torch.Tensor,
    cond: torch.Tensor,
    semantics_fn,
    sparse_depth: torch.Tensor,
    sparse_mask: torch.Tensor,
    mu_temporal: torch.Tensor | None = None,
    P_temporal: torch.Tensor | None = None,
    config: KalmanInLoopConfig | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run sampling with Kalman-in-the-loop denoising.

    At every timestep τ the DiT prediction x̂_0 is fused into a per-pixel Kalman
    state (mu, P) as a measurement with variance σ_τ², where σ_τ = τ / T clamped
    below by ``config.measurement_floor``. The latent is then advanced by the
    sampler using a velocity synthesised from x̂_0. Finally, ``posterior_project``
    is applied with the sparse depth and the current Kalman state.

    Parameters
    ----------
    dit_predict_x0 : callable
        Closure that takes ``(x_tau, tau)`` and returns the predicted clean x_0
        (already in [-0.5, 0.5] normalized space).
    sampler : EulerSampler
        Sampler used for the diffusion ODE step. ``sampler.schedule.T`` is the
        value used to normalise τ.
    timesteps : iterable of scalar timesteps
        Descending, as in PPD. Each element may be a single-element tensor
        (e.g. from iterating a 1-D tensor, on any device) or a Python number.
        It is passed unchanged to ``dit_predict_x0`` and ``sampler.step``.
    x_T : torch.Tensor
        Initial noise, shape (B, 1, H, W).
    cond, semantics_fn :
        Not used in this routine (the closure handles them). They are kept for
        interface compatibility.
    sparse_depth, sparse_mask : torch.Tensor
        Sparse LiDAR depth and its mask, forwarded to ``posterior_project``.
    mu_temporal, P_temporal : torch.Tensor, optional
        Temporal Kalman prior (mean / variance) carried over from the previous
        frame. When omitted, the mean is zero and the variance is
        ``config.init_P``. They are copied onto ``x_T``'s device.
    config : KalmanInLoopConfig, optional
        Hyper-parameters. A fresh default instance is used when omitted.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        ``(x, P)``: the final latent and the final per-pixel posterior variance.
    """
    # Build the default per call instead of sharing one instance created at
    # function-definition time.
    if config is None:
        config = KalmanInLoopConfig()

    B, _, H, W = x_T.shape
    device = x_T.device
    T = sampler.schedule.T

    # Kalman state: per-pixel mean and variance, seeded from the temporal prior.
    if mu_temporal is None:
        mu = torch.zeros(B, 1, H, W, device=device)
    else:
        mu = mu_temporal.to(device=device, copy=True)
    if P_temporal is None:
        P = torch.full((B, 1, H, W), config.init_P, device=device)
    else:
        P = P_temporal.to(device=device, copy=True)

    x = x_T
    for tau in timesteps:
        # Normalised time τ/T as a 0-d tensor on the latent's device. A 0-d
        # tensor broadcasts against (B, 1, H, W). Putting it on `device` avoids
        # a CPU/GPU mismatch when `timesteps` lives on CPU.
        # reshape(()) rejects timesteps that are not single-element.
        t_norm = torch.as_tensor(tau, device=device).float().reshape(()) / T
        sigma_t = t_norm.clamp_min(config.measurement_floor)

        x_hat0 = dit_predict_x0(x, tau)

        # Kalman update treating the DiT prediction as a measurement with
        # variance σ_τ².
        # NOTE(review): `measurement_floor` clamps σ_τ, not σ_τ², so the
        # effective variance floor is measurement_floor**2 — confirm intended.
        meas_var = sigma_t ** 2
        K = P / (P + meas_var)
        mu = mu + K * (x_hat0 - mu)
        P = (1.0 - K) * P

        # Diffusion step with the velocity implied by x̂_0. Under the path
        #   x_t = (1 - t/T) x_0 + (t/T) x_T
        # the noise is x_T = (x_t - (1 - t/T) x_0) / (t/T), and the sampler's
        # v-prediction "pred" is v = x_T - x_0.
        denom = t_norm.clamp_min(1e-6)  # avoid division by zero at τ = 0
        x_T_implied = (x - (1 - t_norm) * x_hat0) / denom
        v_pred = x_T_implied - x_hat0
        x = sampler.step(pred=v_pred, x_t=x, t=tau)

        # Posterior projection (LiDAR + Kalman prior) on the post-step latent.
        x = posterior_project(
            x,
            sigma_t,
            sparse_depth=sparse_depth,
            sparse_mask=sparse_mask,
            R=config.R_proj,
            mu_prior=mu,
            P_prior=P,
            alpha=config.proj_alpha,
        )
    return x, P