| """ |
| Kalman-in-the-loop denoising sampler — Algorithm 1 in paper. |
| |
| Given a DiT that predicts x_0 (or velocity convertible to x_0) from x_τ at each |
| denoising step, run a per-pixel Kalman state estimator across sampling steps. |
| The state accumulates evidence: early steps trust the temporal prior (large σ² |
| ≈ large measurement variance), late steps shift toward the DiT prediction. |
| |
| Variance is monotonically decreasing through denoising. The final posterior |
| variance is returned alongside the depth — `temporal_kalman.py` consumes it as |
| the next frame's prior. |
| """ |
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import torch |
|
|
| from ppd.utils.diffusion.sampler import EulerSampler |
| from ppd.lpd.posterior_projection import posterior_project |
|
|
|
|
@dataclass
class KalmanInLoopConfig:
    """Tunable constants read by ``kalman_in_loop_sample``.

    Field order and defaults are part of the public constructor signature.
    """

    # Forwarded to ``posterior_project`` as its ``R`` argument.
    R_proj: float = 0.1

    # Forwarded to ``posterior_project`` as its ``alpha`` argument.
    proj_alpha: float = 0.1

    # Lower clamp applied to the normalized timestep (tau / T) before it is
    # squared into the per-step measurement variance.
    measurement_floor: float = 1e-3

    # Initial per-pixel state variance used when no temporal prior variance
    # (``P_temporal``) is supplied.
    init_P: float = 1.0
|
|
|
|
@torch.no_grad()
def kalman_in_loop_sample(
    *,
    dit_predict_x0,
    sampler: EulerSampler,
    timesteps: torch.Tensor,
    x_T: torch.Tensor,
    cond: torch.Tensor,
    semantics_fn,
    sparse_depth: torch.Tensor,
    sparse_mask: torch.Tensor,
    mu_temporal: torch.Tensor | None = None,
    P_temporal: torch.Tensor | None = None,
    config: KalmanInLoopConfig | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run sampling with Kalman-in-the-loop denoising.

    At every timestep the DiT's x_0 prediction is treated as a noisy
    measurement of a per-pixel scalar state ``(mu, P)``; the state is updated
    with a standard scalar Kalman step, the sample is advanced with the
    sampler, and the result is passed through ``posterior_project`` using the
    updated state as the prior.

    Parameters
    ----------
    dit_predict_x0 : callable
        Closure that takes (x_tau, tau) and returns predicted clean x_0
        (already in [-0.5,0.5] normalized space).
    sampler : an EulerSampler instance for the diffusion ODE step. Its
        ``schedule.T`` is used to normalize timesteps to tau / T.
    timesteps : iterable of timesteps (descending — same as PPD). Elements
        may live on any device (or be plain numbers); the normalized time is
        moved to ``x_T``'s device before use.
    x_T : initial noise (B,1,H,W).
    cond, semantics_fn :
        Conditioning accepted for interface compatibility; not used directly
        in this routine (the closure handles them).
    sparse_depth, sparse_mask :
        Forwarded unchanged to ``posterior_project`` at every step.
    mu_temporal, P_temporal :
        Optional temporal Kalman prior carried over from the previous frame.
        When omitted, the mean starts at zero and the variance at
        ``config.init_P``. The inputs are cloned, never modified in place.
    config :
        Sampler constants. ``None`` (the default) builds a fresh
        ``KalmanInLoopConfig()`` for this call.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        ``(x, P)``: the sample after the final step and projection, and the
        final per-pixel posterior variance. The Kalman mean ``mu`` is only
        used as the projection prior and is not returned.
    """
    # Build the default per call instead of sharing one mutable instance
    # created at function-definition time.
    if config is None:
        config = KalmanInLoopConfig()

    B, _, H, W = x_T.shape
    device = x_T.device
    T = sampler.schedule.T

    # Kalman state: start from the temporal prior when one is provided.
    if mu_temporal is None:
        mu = torch.zeros(B, 1, H, W, device=device)
    else:
        mu = mu_temporal.clone()
    if P_temporal is None:
        P = torch.full((B, 1, H, W), config.init_P, device=device)
    else:
        P = P_temporal.clone()

    x = x_T
    for tau in timesteps:
        # Normalized time in float32 on the sample's device. Only 0-dim CPU
        # tensors may be mixed with accelerator tensors, and the value is
        # reshaped to 4-D below, so it must be moved explicitly.
        tau_frac = torch.as_tensor(tau, dtype=torch.float32, device=device) / T
        sigma_t = tau_frac.clamp_min(config.measurement_floor)
        x_hat0 = dit_predict_x0(x, tau)

        # Scalar Kalman update with the DiT prediction as the measurement and
        # sigma_t^2 as its variance: noisy (early) steps get a small gain.
        meas_var = sigma_t ** 2
        K = P / (P + meas_var)
        mu = mu + K * (x_hat0 - mu)
        P = (1.0 - K) * P

        # Convert the x_0 prediction into a velocity for the sampler by
        # inverting x_tau = (1 - t) * x_0 + t * noise for the implied noise;
        # the velocity is then noise - x_0. The denominator is clamped to
        # avoid dividing by zero at t == 0.
        t_norm = tau_frac.view(1, 1, 1, 1)
        denom = t_norm.clamp_min(1e-6)
        x_T_implied = (x - (1 - t_norm) * x_hat0) / denom
        v_pred = x_T_implied - x_hat0
        x = sampler.step(pred=v_pred, x_t=x, t=tau)

        # Project the stepped sample using the sparse depth observations and
        # the freshly updated Kalman state as the prior.
        x = posterior_project(
            x,
            sigma_t,
            sparse_depth=sparse_depth,
            sparse_mask=sparse_mask,
            R=config.R_proj,
            mu_prior=mu,
            P_prior=P,
            alpha=config.proj_alpha,
        )

    return x, P
|
|