""" Per-pixel temporal Kalman filter for video — paper §3.4. State at each pixel: log-depth scalar with variance. Process model: warp (μ, P) along optical flow, add per-pixel process noise derived from forward-backward flow consistency. Pixels whose forward-backward error exceeds a threshold are flagged as occluded and have their variance reset (= effectively forgetting stale geometry). Update model: where sparse LiDAR is observed, scalar Kalman update against the measurement with variance R. Elsewhere, the prior passes through. Returned variance is the calibrated per-pixel uncertainty map (paper §3.4 "uncertainty output"). Convert to metric uncertainty via exp(sqrt(P)) - 1. """ from __future__ import annotations from dataclasses import dataclass import torch import torch.nn.functional as F @dataclass class TemporalKalmanConfig: R: float = 0.01 # measurement variance (sparse LiDAR) Q_base: float = 0.005 # process-noise floor alpha: float = 0.5 # process-noise scaling on flow consistency P_max: float = 10.0 # variance reset on occlusion P_init: float = 1.0 # initial variance occ_threshold: float = 2.0 # forward-backward error threshold (pixels) def _backward_warp(x: torch.Tensor, flow: torch.Tensor) -> torch.Tensor: """Backward warp `x` (B,C,H,W) by `flow` (B,2,H,W) — flow is sampled at the target grid and tells where to fetch from in the source. """ B, _, H, W = x.shape yy, xx = torch.meshgrid( torch.arange(H, device=x.device, dtype=x.dtype), torch.arange(W, device=x.device, dtype=x.dtype), indexing="ij", ) grid_x = xx[None, None] + flow[:, 0:1] grid_y = yy[None, None] + flow[:, 1:2] # normalize to [-1, 1] grid_x = 2.0 * grid_x / max(W - 1, 1) - 1.0 grid_y = 2.0 * grid_y / max(H - 1, 1) - 1.0 grid = torch.cat([grid_x, grid_y], dim=1).permute(0, 2, 3, 1) return F.grid_sample(x, grid, mode="bilinear", padding_mode="border", align_corners=True) def forward_backward_error( flow_fwd: torch.Tensor, flow_bwd: torch.Tensor ) -> torch.Tensor: """ε(p) = || p + f_fwd(p) + f_bwd(p + f_fwd(p)) || (paper Eq. 6) Returns per-pixel error (B,1,H,W) in pixels. """ bwd_at_fwd = _backward_warp(flow_bwd, flow_fwd) err = flow_fwd + bwd_at_fwd return err.norm(dim=1, keepdim=True) class TemporalKalmanFilter: """Stateful per-pixel Kalman over a video sequence. Usage: kf = TemporalKalmanFilter(shape=(B, 1, H, W), device=...) for k, frame in enumerate(frames): if k > 0: kf.predict(flow_prev_to_curr, flow_curr_to_prev) mu, P = kf.update(sparse_depth_k, sparse_mask_k) """ def __init__( self, shape: tuple[int, int, int, int], device: torch.device, config: TemporalKalmanConfig = TemporalKalmanConfig(), ): self.config = config self.mu = torch.zeros(shape, device=device) self.P = torch.full(shape, config.P_init, device=device) self.has_state = False # set after first update def reset(self) -> None: self.mu.zero_() self.P.fill_(self.config.P_init) self.has_state = False def predict(self, flow_fwd: torch.Tensor, flow_bwd: torch.Tensor) -> None: """Predict step: warp state + inflate variance by Q + reset occluded pixels.""" if not self.has_state: return cfg = self.config # 1) warp state along flow self.mu = _backward_warp(self.mu, flow_fwd) self.P = _backward_warp(self.P, flow_fwd) # 2) compute fb error and per-pixel Q eps = forward_backward_error(flow_fwd, flow_bwd) Q = cfg.Q_base + cfg.alpha * (eps ** 2) self.P = self.P + Q # 3) occlusion: reset variance to P_max (mu effectively floats free) occ = eps > cfg.occ_threshold self.P = torch.where(occ, torch.full_like(self.P, cfg.P_max), self.P) def update( self, sparse_depth: torch.Tensor, sparse_mask: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """Measurement update at observed pixels.""" cfg = self.config m = sparse_mask.float() K = self.P / (self.P + cfg.R) self.mu = self.mu + m * K * (sparse_depth - self.mu) self.P = self.P * (1.0 - m * K) self.has_state = True return self.mu, self.P def absorb_measurement( self, mu_meas: torch.Tensor, P_meas: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: """Absorb a *dense* measurement (e.g. the within-denoising posterior) into the temporal state. Useful for chaining a Kalman-in-loop result into the next frame. """ if not self.has_state: self.mu = mu_meas.clone() self.P = P_meas.clone() self.has_state = True return self.mu, self.P K = self.P / (self.P + P_meas) self.mu = self.mu + K * (mu_meas - self.mu) self.P = (1.0 - K) * self.P return self.mu, self.P def metric_uncertainty(self) -> torch.Tensor: """Std-dev in log-depth → (multiplicative) metric uncertainty. Following paper §3.4: σ_metric ≈ exp(sqrt(P)) - 1. """ return torch.exp(self.P.clamp_min(0.0).sqrt()) - 1.0