| """ |
| Per-pixel temporal Kalman filter for video — paper §3.4. |
| |
| State at each pixel: log-depth scalar with variance. |
| Process model: warp (μ, P) along optical flow, add per-pixel process noise |
| derived from forward-backward flow consistency. Pixels whose forward-backward |
| error exceeds a threshold are flagged as occluded and have their variance |
| reset (= effectively forgetting stale geometry). |
| |
| Update model: where sparse LiDAR is observed, scalar Kalman update against the |
| measurement with variance R. Elsewhere, the prior passes through. |
| |
| Returned variance is the calibrated per-pixel uncertainty map (paper §3.4 |
| "uncertainty output"). Convert to metric uncertainty via exp(sqrt(P)) - 1. |
| """ |
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
|
|
@dataclass()
class TemporalKalmanConfig:
    """Hyper-parameters of the per-pixel temporal Kalman filter.

    All variances are expressed in the units of the filtered state (log-depth).
    """

    # Measurement-noise variance of one sparse depth sample (Kalman R).
    R: float = 1e-2
    # Constant process noise added to P on every predict step.
    Q_base: float = 5e-3
    # Weight applied to the squared forward-backward flow error in the process noise.
    alpha: float = 0.5
    # Variance assigned to pixels flagged as occluded during predict.
    P_max: float = 1e1
    # Variance the state starts from, and is restored to on reset.
    P_init: float = 1.0
    # Forward-backward flow error (in pixels) above which a pixel counts as occluded.
    occ_threshold: float = 2.0
|
|
|
|
def _backward_warp(x: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Backward-warp ``x`` (B,C,H,W) by ``flow`` (B,2,H,W).

    ``flow`` lives on the *target* grid: output pixel ``p`` is bilinearly
    sampled from ``x`` at ``p + flow(p)``. Channel 0 of ``flow`` is the
    horizontal (column) displacement and channel 1 the vertical (row)
    displacement, both in pixels. Samples that land outside ``x`` take the
    nearest border value.

    ``F.grid_sample`` requires the input and the sampling grid to share a
    dtype. The grid is therefore built in a promoted precision (at least
    float32) and cast to ``x.dtype`` at the end, so ``x`` and ``flow`` may
    have different floating dtypes.

    Returns:
        The warped tensor, same shape and dtype as ``x``.
    """
    H, W = x.shape[-2:]
    coord_dtype = torch.promote_types(
        torch.promote_types(x.dtype, flow.dtype), torch.float32
    )
    yy, xx = torch.meshgrid(
        torch.arange(H, device=x.device, dtype=coord_dtype),
        torch.arange(W, device=x.device, dtype=coord_dtype),
        indexing="ij",
    )
    grid_x = xx[None, None] + flow[:, 0:1].to(coord_dtype)
    grid_y = yy[None, None] + flow[:, 1:2].to(coord_dtype)

    # Map pixel coordinates to [-1, 1]. Pixel 0 -> -1 and pixel W-1 -> +1,
    # which is the align_corners=True convention used below. max(..., 1)
    # guards the degenerate 1-pixel-wide/high case.
    grid_x = 2.0 * grid_x / max(W - 1, 1) - 1.0
    grid_y = 2.0 * grid_y / max(H - 1, 1) - 1.0
    grid = torch.cat([grid_x, grid_y], dim=1).permute(0, 2, 3, 1).to(x.dtype)
    return F.grid_sample(x, grid, mode="bilinear", padding_mode="border", align_corners=True)
|
|
|
|
def forward_backward_error(
    flow_fwd: torch.Tensor, flow_bwd: torch.Tensor
) -> torch.Tensor:
    """Forward-backward optical-flow consistency error (paper Eq. 6).

        ε(p) = || f_fwd(p) + f_bwd(p + f_fwd(p)) ||

    ``flow_fwd`` (B,2,H,W) is read at ``p`` on its own pixel grid.
    ``flow_bwd`` is bilinearly sampled at the landing point
    ``p + f_fwd(p)``. For a mutually consistent flow pair the two
    displacements cancel and ε ≈ 0; a large ε marks pixels whose round trip
    does not return home.

    The result is defined on the grid that ``flow_fwd`` lives on. Swap the
    arguments to obtain the error on the other frame's grid.

    Returns:
        Per-pixel error (B,1,H,W) in pixels.
    """
    # Follow flow_fwd, then fetch flow_bwd at the landing point and add it:
    # the sum is the round-trip displacement.
    bwd_at_fwd = _backward_warp(flow_bwd, flow_fwd)
    round_trip = flow_fwd + bwd_at_fwd
    return round_trip.norm(dim=1, keepdim=True)
|
|
|
|
class TemporalKalmanFilter:
    """Stateful per-pixel scalar Kalman filter over a video sequence.

    Holds a per-pixel mean ``mu`` and variance ``P`` of the ``shape`` given at
    construction.

    - ``predict`` carries the state from the previous frame to the current
      one along optical flow.
    - ``update`` fuses sparse measurements at masked pixels.
    - ``absorb_measurement`` fuses a dense (mean, variance) estimate.

    Usage:
        kf = TemporalKalmanFilter(shape=(B, 1, H, W), device=...)
        for k, frame in enumerate(frames):
            if k > 0:
                kf.predict(flow_prev_to_curr, flow_curr_to_prev)
            mu, P = kf.update(sparse_depth_k, sparse_mask_k)
    """

    def __init__(
        self,
        shape: tuple[int, int, int, int],
        device: torch.device,
        config: TemporalKalmanConfig | None = None,
    ):
        """Create a filter with ``mu = 0`` and ``P = config.P_init`` everywhere.

        Args:
            shape: shape of the state tensors, e.g. (B, 1, H, W).
            device: device on which the state tensors are allocated.
            config: filter parameters. When omitted, a fresh
                ``TemporalKalmanConfig()`` is created for this filter.
        """
        # Build the default per instance. A default evaluated once in the
        # signature would be a single mutable object shared by every filter.
        self.config = config if config is not None else TemporalKalmanConfig()
        self.mu = torch.zeros(shape, device=device)
        self.P = torch.full(shape, self.config.P_init, device=device)
        # False until the first update/absorb; predict() is a no-op until then.
        self.has_state = False

    def reset(self) -> None:
        """Return the state to its initial values (mu = 0, P = P_init)."""
        # Allocate new tensors instead of zeroing in place. update() and
        # absorb_measurement() hand self.mu / self.P back to the caller, and
        # in-place ops here would silently overwrite those returned results.
        self.mu = torch.zeros_like(self.mu)
        self.P = torch.full_like(self.P, self.config.P_init)
        self.has_state = False

    def predict(self, flow_fwd: torch.Tensor, flow_bwd: torch.Tensor) -> None:
        """Predict step: move the state from the previous frame to the current one.

        The step runs in three stages:
        1. Warp (mu, P) onto the current frame's grid.
        2. Inflate P by Q = Q_base + alpha * eps**2, where eps is the
           forward-backward flow error on the current grid.
        3. Set P to P_max wherever eps exceeds occ_threshold (occlusion).

        Does nothing until the filter holds a state.

        Args:
            flow_fwd: previous -> current flow (B,2,H,W), defined on the
                previous frame's grid (``flow_prev_to_curr``).
            flow_bwd: current -> previous flow (B,2,H,W), defined on the
                current frame's grid (``flow_curr_to_prev``).
        """
        if not self.has_state:
            return
        cfg = self.config

        # _backward_warp samples its input at p + flow(p), with the flow
        # defined on the *target* grid. The target is the current frame and
        # the source the previous one, so the flow must be flow_bwd
        # (current -> previous).
        self.mu = _backward_warp(self.mu, flow_bwd)
        self.P = _backward_warp(self.P, flow_bwd)

        # Evaluate the consistency error on the same (current) grid as the
        # warped state: eps(p) = ||f_bwd(p) + f_fwd(p + f_bwd(p))||.
        eps = forward_backward_error(flow_bwd, flow_fwd)
        Q = cfg.Q_base + cfg.alpha * eps.pow(2)
        self.P = self.P + Q

        # Occluded pixels forget their warped history via a large variance.
        # NOTE: P of non-occluded pixels is not capped at P_max here.
        occ = eps > cfg.occ_threshold
        self.P = torch.where(occ, torch.full_like(self.P, cfg.P_max), self.P)

    def update(
        self,
        sparse_depth: torch.Tensor,
        sparse_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Scalar Kalman measurement update at observed pixels.

        Args:
            sparse_depth: measurements, broadcastable to the state. Only
                pixels where ``sparse_mask > 0`` are read; values elsewhere
                (e.g. NaN/inf fill) are ignored. NOTE(review): the state is
                described as log-depth, so confirm callers pass log-depth here.
            sparse_mask: observation mask (bool or float). A value of 0 leaves
                the prior untouched. Non-binary values scale the gain.

        Returns:
            The updated ``(mu, P)``.
        """
        cfg = self.config
        m = sparse_mask.float()
        K = self.P / (self.P + cfg.R)
        # Zero the innovation where nothing was measured. Otherwise a NaN/inf
        # fill value leaks through m * K * (NaN) = NaN.
        innovation = torch.where(
            m > 0, sparse_depth - self.mu, torch.zeros_like(self.mu)
        )
        self.mu = self.mu + m * K * innovation
        self.P = self.P * (1.0 - m * K)
        self.has_state = True
        return self.mu, self.P

    def absorb_measurement(
        self, mu_meas: torch.Tensor, P_meas: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Fuse a *dense* measurement (mean, variance) into the temporal state.

        Useful for chaining a Kalman-in-loop result (e.g. the within-denoising
        posterior) into the next frame. If the filter has no state yet, the
        measurement is adopted as the state (copied).

        Note that a pixel where both ``P`` and ``P_meas`` are zero gives a 0/0
        gain (NaN).

        Returns:
            The fused ``(mu, P)``.
        """
        if not self.has_state:
            self.mu = mu_meas.clone()
            self.P = P_meas.clone()
            self.has_state = True
            return self.mu, self.P
        K = self.P / (self.P + P_meas)
        self.mu = self.mu + K * (mu_meas - self.mu)
        self.P = (1.0 - K) * self.P
        return self.mu, self.P

    def metric_uncertainty(self) -> torch.Tensor:
        """Convert log-depth variance into a multiplicative metric uncertainty.

        Following paper §3.4: σ_metric ≈ exp(sqrt(P)) - 1. Negative variances
        are clamped to zero before the square root.
        """
        return torch.exp(self.P.clamp_min(0.0).sqrt()) - 1.0
|
|