LiDAR-Perfect-Depth / code /ppd /lpd /temporal_kalman.py
chenming-wu's picture
code
436b829 verified
"""
Per-pixel temporal Kalman filter for video — paper §3.4.
State at each pixel: log-depth scalar with variance.
Process model: warp (μ, P) along optical flow, add per-pixel process noise
derived from forward-backward flow consistency. Pixels whose forward-backward
error exceeds a threshold are flagged as occluded and have their variance
reset to a large value (P_max), which effectively forgets stale geometry.
Update model: where sparse LiDAR is observed, scalar Kalman update against the
measurement with variance R. Elsewhere, the prior passes through.
Returned variance is the calibrated per-pixel uncertainty map (paper §3.4
"uncertainty output"). Because the state is log-depth, convert it to a
multiplicative metric uncertainty via exp(sqrt(P)) - 1.
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
import torch.nn.functional as F
@dataclass
class TemporalKalmanConfig:
    """Hyper-parameters of the per-pixel temporal Kalman filter.

    Variances are expressed in the filter's state units (log-depth according
    to the module docstring).
    """

    # Measurement variance assigned to sparse LiDAR samples.
    R: float = 1e-2
    # Floor of the process noise added to every pixel in each predict step.
    Q_base: float = 5e-3
    # Weight of the squared forward-backward flow error in the process noise.
    alpha: float = 0.5
    # Variance assigned to pixels flagged as occluded.
    P_max: float = 1e1
    # Variance the state starts from (and returns to on reset).
    P_init: float = 1.0
    # Forward-backward flow error (pixels) above which a pixel is occluded.
    occ_threshold: float = 2.0
def _backward_warp(x: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Backward-warp ``x`` by ``flow`` with bilinear sampling.

    ``flow`` is defined on the *target* grid: for each target pixel ``p`` the
    output is ``x`` sampled at ``p + flow(p)`` (i.e. it tells where to fetch
    from in the source). Out-of-range lookups are clamped to the border.

    Args:
        x: Source tensor, (B, C, H, W).
        flow: Pixel displacements, (B, 2, H, W); channel 0 is x, channel 1 is y.

    Returns:
        Warped tensor with the same shape and dtype as ``x``.
    """
    _, _, H, W = x.shape
    # Build coordinates in at least float32: half precision cannot represent
    # integer pixel positions above 2048 exactly. The grid is cast back to
    # x.dtype at the end because grid_sample requires matching dtypes.
    coord_dtype = torch.promote_types(
        torch.promote_types(x.dtype, flow.dtype), torch.float32
    )
    ys = torch.arange(H, device=x.device, dtype=coord_dtype)
    xs = torch.arange(W, device=x.device, dtype=coord_dtype)
    yy, xx = torch.meshgrid(ys, xs, indexing="ij")
    grid_x = xx + flow[:, 0].to(coord_dtype)
    grid_y = yy + flow[:, 1].to(coord_dtype)
    # Normalize to [-1, 1]; with align_corners=True, -1/+1 are the centres of
    # the first/last pixel. max(..., 1) avoids dividing by zero for size-1 dims.
    grid_x = 2.0 * grid_x / max(W - 1, 1) - 1.0
    grid_y = 2.0 * grid_y / max(H - 1, 1) - 1.0
    grid = torch.stack((grid_x, grid_y), dim=-1).to(x.dtype)  # (B, H, W, 2)
    return F.grid_sample(
        x, grid, mode="bilinear", padding_mode="border", align_corners=True
    )
def forward_backward_error(
    flow_fwd: torch.Tensor, flow_bwd: torch.Tensor
) -> torch.Tensor:
    """Per-pixel forward-backward flow consistency error (paper Eq. 6).

    For every pixel ``p`` of the grid on which ``flow_fwd`` is defined::

        eps(p) = || f_fwd(p) + f_bwd(p + f_fwd(p)) ||

    i.e. follow the forward flow, then the backward flow looked up
    (bilinearly, border-clamped) at the landing point. For consistent flows
    the round trip returns to ``p`` and the error is 0.

    Args:
        flow_fwd: Flow (B, 2, H, W) defined on the grid where eps is evaluated.
        flow_bwd: Flow (B, 2, H, W) in the opposite direction.

    Returns:
        Round-trip displacement magnitude in pixels, shape (B, 1, H, W).
    """
    # f_bwd sampled at p + f_fwd(p).
    bwd_at_fwd = _backward_warp(flow_bwd, flow_fwd)
    round_trip = flow_fwd + bwd_at_fwd
    return round_trip.norm(dim=1, keepdim=True)
class TemporalKalmanFilter:
    """Stateful per-pixel Kalman filter over a video sequence.

    The state is a per-pixel mean ``mu`` and variance ``P`` (log-depth
    according to the module docstring). Per frame:

    * ``predict`` carries the previous frame's state onto the current frame's
      pixel grid, inflates the variance by a process noise that grows with
      forward-backward flow inconsistency, and resets the variance of pixels
      flagged as occluded to ``config.P_max``;
    * ``update`` fuses a sparse measurement where the mask is set.

    Usage:
        kf = TemporalKalmanFilter(shape=(B, 1, H, W), device=...)
        for k, frame in enumerate(frames):
            if k > 0:
                kf.predict(flow_prev_to_curr, flow_curr_to_prev)
            mu, P = kf.update(sparse_depth_k, sparse_mask_k)
    """

    def __init__(
        self,
        shape: tuple[int, int, int, int],
        device: torch.device,
        config: TemporalKalmanConfig | None = None,
    ):
        """Create a filter with mean 0 and variance ``config.P_init``.

        Args:
            shape: State shape, e.g. (B, 1, H, W).
            device: Device on which the state tensors are allocated.
            config: Hyper-parameters; a fresh default config when omitted.
        """
        # A fresh config per filter: a default instance evaluated once at
        # definition time would be a mutable object shared by every filter.
        self.config = config if config is not None else TemporalKalmanConfig()
        self.mu = torch.zeros(shape, device=device)
        self.P = torch.full(shape, self.config.P_init, device=device)
        self.has_state = False  # set after the first update / absorb

    def reset(self) -> None:
        """Forget all history: mean back to 0, variance back to ``P_init``.

        New tensors are allocated instead of modifying the state in place, so
        tensors previously returned by ``update``/``absorb_measurement``
        (which are the state tensors themselves) are left untouched.
        """
        self.mu = torch.zeros_like(self.mu)
        self.P = torch.full_like(self.P, self.config.P_init)
        self.has_state = False

    def predict(self, flow_fwd: torch.Tensor, flow_bwd: torch.Tensor) -> None:
        """Propagate the state from the previous frame onto the current frame.

        Warps (mu, P), adds process noise Q = Q_base + alpha * eps**2 and
        resets the variance of occluded pixels (eps > occ_threshold) to P_max.
        Does nothing until the filter holds a state.

        Args:
            flow_fwd: Previous->current flow (B, 2, H, W), defined on the
                previous frame's grid.
            flow_bwd: Current->previous flow (B, 2, H, W), defined on the
                current frame's grid.
        """
        if not self.has_state:
            return
        cfg = self.config
        # 1) Backward warping needs a flow that lives on the *target* (current)
        #    grid and points into the source (previous) frame: flow_bwd.
        #    Warping with flow_fwd would sample the previous state in the
        #    opposite direction of the motion.
        self.mu = _backward_warp(self.mu, flow_bwd)
        self.P = _backward_warp(self.P, flow_bwd)
        # 2) Forward-backward consistency evaluated on the current grid:
        #    eps(p) = || f_bwd(p) + f_fwd(p + f_bwd(p)) ||.
        eps = forward_backward_error(flow_bwd, flow_fwd)
        self.P = self.P + (cfg.Q_base + cfg.alpha * eps.square())
        # 3) Occluded / disoccluded pixels have no reliable history: reset the
        #    variance so the next measurement dominates (mu keeps the warped
        #    value).
        occ = eps > cfg.occ_threshold
        self.P = torch.where(occ, torch.full_like(self.P, cfg.P_max), self.P)

    def update(
        self,
        sparse_depth: torch.Tensor,
        sparse_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Scalar Kalman measurement update at observed pixels.

        Args:
            sparse_depth: Measurement in the state's units, broadcastable to
                the state. Values where the mask is 0 are ignored, even if
                non-finite (e.g. log(0) = -inf for missing LiDAR returns).
            sparse_mask: Observation mask: bool, or float weights (a fractional
                weight scales the Kalman gain).

        Returns:
            ``(mu, P)``: the posterior state. These are the filter's own state
            tensors; treat them as read-only.
        """
        cfg = self.config
        m = sparse_mask.to(self.mu.dtype)
        K = self.P / (self.P + cfg.R)
        # Zero the innovation where nothing was observed so junk values there
        # cannot leak into the state via 0 * inf = NaN.
        innovation = torch.where(
            m > 0, sparse_depth - self.mu, torch.zeros_like(self.mu)
        )
        self.mu = self.mu + m * K * innovation
        self.P = self.P * (1.0 - m * K)
        self.has_state = True
        return self.mu, self.P

    def absorb_measurement(
        self, mu_meas: torch.Tensor, P_meas: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Absorb a *dense* measurement (e.g. the within-denoising posterior)
        into the temporal state. Useful for chaining a Kalman-in-loop result
        into the next frame.

        Without a prior state the measurement is copied in as-is; otherwise
        the two Gaussians are fused per pixel with gain K = P / (P + P_meas).

        Returns:
            ``(mu, P)``: the fused state (the filter's own tensors).
        """
        if not self.has_state:
            self.mu = mu_meas.clone()
            self.P = P_meas.clone()
            self.has_state = True
            return self.mu, self.P
        K = self.P / (self.P + P_meas)
        self.mu = self.mu + K * (mu_meas - self.mu)
        self.P = (1.0 - K) * self.P
        return self.mu, self.P

    def metric_uncertainty(self) -> torch.Tensor:
        """Std-dev in log-depth -> (multiplicative) metric uncertainty.

        Following paper §3.4: sigma_metric ~= exp(sqrt(P)) - 1. ``expm1`` keeps
        precision when sqrt(P) is small; negative variances (numerical noise)
        are clamped to zero first.
        """
        return torch.expm1(self.P.clamp_min(0.0).sqrt())