"""
Kalman-in-the-loop denoising sampler — Algorithm 1 in paper.

Given a DiT that predicts x_0 (or velocity convertible to x_0) from x_τ at
each denoising step, run a per-pixel Kalman state estimator across sampling
steps. The state accumulates evidence: early steps trust the temporal prior
(large σ² ≈ large measurement variance), late steps shift toward the DiT
prediction. Variance is monotonically decreasing through denoising.

The final posterior variance is returned alongside the final latent —
`temporal_kalman.py` consumes it as the next frame's prior.
"""

from __future__ import annotations

from dataclasses import dataclass

import torch

from ppd.utils.diffusion.sampler import EulerSampler
from ppd.lpd.posterior_projection import posterior_project


@dataclass
class KalmanInLoopConfig:
    """Hyper-parameters for :func:`kalman_in_loop_sample`."""

    # Projection-step LiDAR variance, passed to posterior_project as R (paper §4.4).
    R_proj: float = 0.1
    # Step-size scale α, passed to posterior_project as alpha.
    proj_alpha: float = 0.1
    # Lower bound on σ_τ = τ/T. The per-step measurement variance is σ_τ², so it
    # is effectively floored at measurement_floor**2 (1e-6 with the default).
    # NOTE(review): this was previously described as a *variance* floor —
    # confirm whether the clamp is meant to apply to σ_τ or to σ_τ².
    measurement_floor: float = 1e-3
    # Prior variance used for every pixel when no temporal P is supplied.
    init_P: float = 1.0


def _implied_velocity(
    x_t: torch.Tensor, x_hat0: torch.Tensor, t_norm: torch.Tensor
) -> torch.Tensor:
    """Return the v-prediction (x_T - x_0) implied by an x_0 estimate.

    Assumes the schedule x_t = (1 - t/T) * x_0 + (t/T) * x_T, solves it for
    x_T, and returns x_T - x_0. ``t_norm`` (= t/T) is clamped at 1e-6 so the
    division stays finite at t = 0.
    """
    denom = t_norm.clamp_min(1e-6)
    x_T_implied = (x_t - (1 - t_norm) * x_hat0) / denom
    return x_T_implied - x_hat0


@torch.no_grad()
def kalman_in_loop_sample(
    *,
    dit_predict_x0,
    sampler: EulerSampler,
    timesteps: torch.Tensor,
    x_T: torch.Tensor,
    cond: torch.Tensor,
    semantics_fn,
    sparse_depth: torch.Tensor,
    sparse_mask: torch.Tensor,
    mu_temporal: torch.Tensor | None = None,
    P_temporal: torch.Tensor | None = None,
    config: KalmanInLoopConfig | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run sampling with Kalman-in-the-loop denoising.

    Parameters
    ----------
    dit_predict_x0 : callable
        Closure that takes (x_tau, tau) and returns predicted clean x_0
        (already in [-0.5,0.5] normalized space).
    sampler : EulerSampler
        Sampler used for the diffusion ODE step; ``sampler.schedule.T`` is
        the total number of timesteps.
    timesteps : torch.Tensor
        Iterable of timesteps (descending — same as PPD). Each element must
        be a single-element tensor; one timestep is shared by the batch.
    x_T : torch.Tensor
        Initial noise (B,1,H,W).
    cond, semantics_fn :
        Not used directly in this routine; the ``dit_predict_x0`` closure is
        expected to handle conditioning.
    sparse_depth, sparse_mask : torch.Tensor
        Sparse LiDAR measurements and their validity mask, forwarded to
        ``posterior_project`` after every sampler step.
    mu_temporal, P_temporal : torch.Tensor, optional
        Temporal Kalman prior (mean, variance) carried over from the previous
        frame. When omitted, the mean starts at 0 and the variance at
        ``config.init_P``. Both are cloned, never modified in place.
    config : KalmanInLoopConfig, optional
        Hyper-parameters. A fresh ``KalmanInLoopConfig()`` is used when
        omitted.

    Returns
    -------
    x : torch.Tensor
        Latent after the last sampler step and posterior projection.
    P : torch.Tensor
        Final per-pixel Kalman posterior variance (B,1,H,W when initialised
        here; otherwise the shape of ``P_temporal``).
    """
    if config is None:
        # Built per call instead of as a default argument, so no single
        # mutable config instance is shared across every call.
        config = KalmanInLoopConfig()

    B, _, H, W = x_T.shape
    device = x_T.device
    T = sampler.schedule.T

    # Initialize the Kalman state.
    if mu_temporal is None:
        mu = torch.zeros(B, 1, H, W, device=device)
    else:
        mu = mu_temporal.clone()
    if P_temporal is None:
        P = torch.full((B, 1, H, W), config.init_P, device=device)
    else:
        P = P_temporal.clone()

    x = x_T
    for tau in timesteps:
        # σ_τ = τ/T, clamped so the measurement variance never reaches zero.
        sigma_t = (tau.float() / T).clamp_min(config.measurement_floor)
        x_hat0 = dit_predict_x0(x, tau)

        # Kalman update treating the DiT prediction as a measurement with
        # variance σ_τ². Large τ gives a small gain K (the prior dominates);
        # as τ shrinks, K grows and mu moves toward x_hat0.
        meas_var = sigma_t ** 2
        K = P / (P + meas_var)
        mu = mu + K * (x_hat0 - mu)
        P = (1.0 - K) * P

        # Diffusion step. The sampler expects the model's "pred" — for
        # v-prediction this is v = x_T - x_0 — so synthesize it from x_hat0
        # and the current state. view(1, 1, 1, 1) requires tau to hold a
        # single value.
        t_norm = (tau.float() / T).view(1, 1, 1, 1)
        v_pred = _implied_velocity(x, x_hat0, t_norm)
        x = sampler.step(pred=v_pred, x_t=x, t=tau)

        # Posterior projection (LiDAR + Kalman prior) on the post-step latent.
        x = posterior_project(
            x,
            sigma_t,
            sparse_depth=sparse_depth,
            sparse_mask=sparse_mask,
            R=config.R_proj,
            mu_prior=mu,
            P_prior=P,
            alpha=config.proj_alpha,
        )
    return x, P