"""
Per-pixel temporal Kalman filter for video — paper §3.4.
State at each pixel: log-depth scalar with variance.
Process model: warp (μ, P) along optical flow, add per-pixel process noise
derived from forward-backward flow consistency. Pixels whose forward-backward
error exceeds a threshold are flagged as occluded and have their variance
reset (= effectively forgetting stale geometry).
Update model: where sparse LiDAR is observed, scalar Kalman update against the
measurement with variance R. Elsewhere, the prior passes through.
Returned variance is the calibrated per-pixel uncertainty map (paper §3.4
"uncertainty output"). Convert to metric uncertainty via exp(sqrt(P)) - 1.
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
import torch.nn.functional as F
@dataclass
class TemporalKalmanConfig:
    """Tunable constants for the per-pixel temporal Kalman filter.

    Attributes:
        R: Measurement variance (sparse LiDAR).
        Q_base: Process-noise floor.
        alpha: Process-noise scaling on flow consistency.
        P_max: Variance that pixels are reset to on occlusion.
        P_init: Initial variance.
        occ_threshold: Forward-backward error threshold, in pixels.
    """

    R: float = 0.01
    Q_base: float = 0.005
    alpha: float = 0.5
    P_max: float = 10.0
    P_init: float = 1.0
    occ_threshold: float = 2.0
def _backward_warp(x: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
"""Backward warp `x` (B,C,H,W) by `flow` (B,2,H,W) — flow is sampled at the
target grid and tells where to fetch from in the source.
"""
B, _, H, W = x.shape
yy, xx = torch.meshgrid(
torch.arange(H, device=x.device, dtype=x.dtype),
torch.arange(W, device=x.device, dtype=x.dtype),
indexing="ij",
)
grid_x = xx[None, None] + flow[:, 0:1]
grid_y = yy[None, None] + flow[:, 1:2]
# normalize to [-1, 1]
grid_x = 2.0 * grid_x / max(W - 1, 1) - 1.0
grid_y = 2.0 * grid_y / max(H - 1, 1) - 1.0
grid = torch.cat([grid_x, grid_y], dim=1).permute(0, 2, 3, 1)
return F.grid_sample(x, grid, mode="bilinear", padding_mode="border", align_corners=True)
def forward_backward_error(
    flow_fwd: torch.Tensor, flow_bwd: torch.Tensor
) -> torch.Tensor:
    """Forward-backward flow consistency residual (cf. paper Eq. 6).

    Evaluates, for every pixel p of the grid `flow_fwd` is defined on,

        ε(p) = || f_fwd(p) + f_bwd(p + f_fwd(p)) ||

    i.e. how far a pixel ends up from where it started after following the
    forward flow and then the backward flow sampled at the landing position.

    Returns the per-pixel error (B,1,H,W) in pixels.
    """
    # Backward flow looked up where each pixel lands after the forward step.
    bwd_at_landing = _backward_warp(flow_bwd, flow_fwd)
    round_trip = flow_fwd + bwd_at_landing
    # Euclidean length over the (x, y) channel dimension.
    return torch.linalg.vector_norm(round_trip, dim=1, keepdim=True)
class TemporalKalmanFilter:
    """Stateful per-pixel scalar Kalman filter over a video sequence.

    The state is a mean `mu` and a variance `P` per pixel, both living on the
    pixel grid of the most recently processed frame. `predict` carries the
    state onto the next frame's grid, `update` fuses sparse measurements and
    `absorb_measurement` fuses a dense measurement with its own variance.

    Usage:
        kf = TemporalKalmanFilter(shape=(B, 1, H, W), device=...)
        for k, frame in enumerate(frames):
            if k > 0:
                kf.predict(flow_prev_to_curr, flow_curr_to_prev)
            mu, P = kf.update(sparse_depth_k, sparse_mask_k)
    """

    def __init__(
        self,
        shape: tuple[int, int, int, int],
        device: torch.device,
        config: "TemporalKalmanConfig | None" = None,
    ):
        """Allocate a zero-mean state with variance `config.P_init`.

        Args:
            shape: Shape of the state tensors, e.g. (B, 1, H, W).
            device: Device the state tensors are created on.
            config: Filter constants. When omitted, a fresh
                `TemporalKalmanConfig()` is created for this filter.
        """
        # Build the default here rather than in the signature: a default
        # instance in the signature is created once and shared (and mutable)
        # across every filter constructed without an explicit config.
        self.config = TemporalKalmanConfig() if config is None else config
        self.mu = torch.zeros(shape, device=device)
        self.P = torch.full(shape, self.config.P_init, device=device)
        self.has_state = False  # set after first update

    def reset(self) -> None:
        """Forget everything: zero mean, variance `P_init`, no state."""
        # Allocate new tensors instead of zeroing in place: `update` and
        # `absorb_measurement` hand out references to `self.mu` / `self.P`,
        # and an in-place reset would silently overwrite those results.
        self.mu = torch.zeros_like(self.mu)
        self.P = torch.full_like(self.P, self.config.P_init)
        self.has_state = False

    def predict(self, flow_fwd: torch.Tensor, flow_bwd: torch.Tensor) -> None:
        """Predict step: warp state + inflate variance by Q + reset occluded pixels.

        Args:
            flow_fwd: Flow from the previous frame to the current one,
                defined on the previous frame's grid, (B,2,H,W) in pixels.
            flow_bwd: Flow from the current frame to the previous one,
                defined on the current frame's grid, (B,2,H,W) in pixels.

        Does nothing until the filter holds a state (see `has_state`).
        """
        if not self.has_state:
            return
        cfg = self.config
        # 1) Bring the state onto the current grid. `_backward_warp` fetches
        #    x(q + flow(q)) for every target pixel q, so it needs the flow
        #    that is defined on the current grid and points back into the
        #    previous frame, i.e. flow_bwd. Warping with flow_fwd would move
        #    the state in the direction opposite to the scene motion.
        self.mu = _backward_warp(self.mu, flow_bwd)
        self.P = _backward_warp(self.P, flow_bwd)
        # 2) Flow-consistency error evaluated on the current grid (round trip
        #    current -> previous -> current) so that it is pixel-aligned with
        #    the warped state, then the per-pixel process noise.
        eps = forward_backward_error(flow_bwd, flow_fwd)
        Q = cfg.Q_base + cfg.alpha * (eps ** 2)
        self.P = self.P + Q
        # 3) Occlusion: reset variance to P_max (mu effectively floats free).
        occ = eps > cfg.occ_threshold
        self.P = torch.where(occ, torch.full_like(self.P, cfg.P_max), self.P)

    def update(
        self,
        sparse_depth: torch.Tensor,
        sparse_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Measurement update at observed pixels.

        Args:
            sparse_depth: Measurement map. It is differenced directly against
                the state mean, so it must be expressed in the same domain as
                the state. Values where the mask is zero are ignored and may
                be arbitrary placeholders (0, NaN, inf).
            sparse_mask: Per-pixel weight of the measurement; nonzero where a
                measurement exists, zero elsewhere. Converted to float.

        Returns:
            The posterior `(mu, P)`. Where the mask is zero the prior passes
            through unchanged.
        """
        cfg = self.config
        m = sparse_mask.float()
        K = self.P / (self.P + cfg.R)
        innovation = sparse_depth - self.mu
        # Multiplying by a zero mask does not neutralise NaN/inf placeholders
        # (0 * nan == nan), so blank the innovation explicitly where there is
        # no measurement.
        innovation = torch.where(m != 0, innovation, torch.zeros_like(innovation))
        self.mu = self.mu + m * K * innovation
        self.P = self.P * (1.0 - m * K)
        self.has_state = True
        return self.mu, self.P

    def absorb_measurement(
        self, mu_meas: torch.Tensor, P_meas: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Absorb a *dense* measurement (e.g. the within-denoising posterior)
        into the temporal state. Useful for chaining a Kalman-in-loop result
        into the next frame.

        Args:
            mu_meas: Measured mean at every pixel.
            P_meas: Variance of that measurement at every pixel.

        Returns:
            The posterior `(mu, P)`. If the filter holds no state yet, the
            measurement is adopted as the state (copied, not aliased).
        """
        if not self.has_state:
            self.mu = mu_meas.clone()
            self.P = P_meas.clone()
            self.has_state = True
            return self.mu, self.P
        K = self.P / (self.P + P_meas)
        self.mu = self.mu + K * (mu_meas - self.mu)
        self.P = (1.0 - K) * self.P
        return self.mu, self.P

    def metric_uncertainty(self) -> torch.Tensor:
        """Std-dev in log-depth → (multiplicative) metric uncertainty.

        Following paper §3.4: σ_metric ≈ exp(sqrt(P)) - 1. The variance is
        clamped at zero first so the square root cannot produce NaN.
        """
        return torch.exp(self.P.clamp_min(0.0).sqrt()) - 1.0
|