File size: 5,375 Bytes
436b829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""
Per-pixel temporal Kalman filter for video — paper §3.4.

State at each pixel: log-depth scalar with variance.
Process model: warp (μ, P) along optical flow and add per-pixel process noise
derived from forward-backward flow consistency. Pixels whose forward-backward
error exceeds a threshold are flagged as occluded and have their variance
reset to a large value, which effectively forgets stale geometry.

Update model: where sparse LiDAR is observed, scalar Kalman update against the
measurement with variance R. Elsewhere, the prior passes through.

Returned variance is the calibrated per-pixel uncertainty map (paper §3.4
"uncertainty output"). Convert to metric uncertainty via exp(sqrt(P)) - 1.
"""
from __future__ import annotations

from dataclasses import dataclass

import torch
import torch.nn.functional as F


@dataclass
class TemporalKalmanConfig:
    """Tuning knobs for the temporal Kalman filter.

    Variances are expressed in the (squared) units of the filtered state.
    Field order is part of the interface: positional construction relies on it.
    """

    # Variance R of a single sparse LiDAR measurement.
    R: float = 0.01
    # Constant floor of the process noise Q added at every predict step.
    Q_base: float = 0.005
    # Weight applied to the squared forward-backward flow error inside Q.
    alpha: float = 0.5
    # Variance written into pixels that are flagged as occluded.
    P_max: float = 10.0
    # Variance of a freshly created (or reset) state.
    P_init: float = 1.0
    # Forward-backward error, in pixels, above which a pixel counts as occluded.
    occ_threshold: float = 2.0


def _backward_warp(x: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Backward-warp ``x`` along ``flow`` with bilinear sampling.

    For every target pixel ``p`` the output is ``x`` sampled at ``p + flow(p)``:
    ``flow`` is defined on the *target* grid and tells each target pixel where
    to fetch its value from in the source ``x``.

    Args:
        x: source tensor of shape (B, C, H, W).
        flow: displacement field of shape (B, 2, H, W), in pixels; channel 0
            is the horizontal (column) offset, channel 1 the vertical (row)
            offset.

    Returns:
        Warped tensor with the same shape and dtype as ``x``. Sampling
        positions outside the image are clamped to the border.
    """
    H, W = x.shape[-2:]
    rows, cols = torch.meshgrid(
        torch.arange(H, device=x.device, dtype=x.dtype),
        torch.arange(W, device=x.device, dtype=x.dtype),
        indexing="ij",
    )
    # Absolute sampling positions p + flow(p), in pixels.
    sample_x = cols[None, None] + flow[:, 0:1]
    sample_y = rows[None, None] + flow[:, 1:2]
    # Map pixel coordinates to [-1, 1] using the align_corners=True convention
    # (-1 / +1 are the centres of the first / last pixel). max(..., 1) avoids a
    # division by zero for single-row or single-column images.
    sample_x = 2.0 * sample_x / max(W - 1, 1) - 1.0
    sample_y = 2.0 * sample_y / max(H - 1, 1) - 1.0
    # grid_sample requires grid and input to share a dtype; a flow of another
    # precision would otherwise promote the grid and make grid_sample fail.
    grid = torch.cat([sample_x, sample_y], dim=1).permute(0, 2, 3, 1).to(x.dtype)
    return F.grid_sample(x, grid, mode="bilinear", padding_mode="border", align_corners=True)


def forward_backward_error(
    flow_fwd: torch.Tensor, flow_bwd: torch.Tensor
) -> torch.Tensor:
    """Forward-backward flow consistency error (paper Eq. 6).

    For each pixel p on the grid where ``flow_fwd`` is defined, follow
    ``flow_fwd`` to q = p + f_fwd(p) and then follow ``flow_bwd`` (sampled
    bilinearly at q) back again. For consistent flow the round trip lands on
    p, so the error is the length of the residual displacement

        ε(p) = || f_fwd(p) + f_bwd(p + f_fwd(p)) ||,

    i.e. the distance between p and p + f_fwd(p) + f_bwd(p + f_fwd(p)).

    Args:
        flow_fwd: (B, 2, H, W) flow from a source frame to a target frame,
            defined on the source grid.
        flow_bwd: (B, 2, H, W) flow from the target frame back to the source
            frame, defined on the target grid.

    Returns:
        (B, 1, H, W) per-pixel error in pixels, on the grid of ``flow_fwd``.
    """
    # f_bwd evaluated at the forward-displaced location p + f_fwd(p).
    bwd_at_target = _backward_warp(flow_bwd, flow_fwd)
    residual = flow_fwd + bwd_at_target
    return residual.norm(dim=1, keepdim=True)


class TemporalKalmanFilter:
    """Stateful per-pixel scalar Kalman filter over a video sequence.

    The state is a per-pixel mean ``mu`` and variance ``P`` of shape ``shape``,
    initialised to 0 and ``config.P_init`` respectively.

    Usage:
        kf = TemporalKalmanFilter(shape=(B, 1, H, W), device=...)
        for k, frame in enumerate(frames):
            if k > 0:
                kf.predict(flow_prev_to_curr, flow_curr_to_prev)
            mu, P = kf.update(sparse_depth_k, sparse_mask_k)
    """

    def __init__(
        self,
        shape: tuple[int, int, int, int],
        device: torch.device,
        config: TemporalKalmanConfig | None = None,
    ):
        # Build a fresh config per filter: a default instance in the signature
        # would be created once and shared (and mutated) by every filter.
        self.config = config if config is not None else TemporalKalmanConfig()
        self.mu = torch.zeros(shape, device=device)
        self.P = torch.full(shape, self.config.P_init, device=device)
        self.has_state = False  # set after first update / absorb

    def reset(self) -> None:
        """Return the state to its initial values (mu = 0, P = P_init) in place."""
        self.mu.zero_()
        self.P.fill_(self.config.P_init)
        self.has_state = False

    def predict(self, flow_fwd: torch.Tensor, flow_bwd: torch.Tensor) -> None:
        """Propagate the state from the previous frame to the current one.

        Warps (mu, P) onto the current grid, inflates P by the process noise
        Q = Q_base + alpha * eps**2 and resets occluded pixels to P_max.
        Does nothing until the filter has absorbed a first measurement.

        Args:
            flow_fwd: optical flow previous -> current, (B, 2, H, W), defined
                on the previous frame's grid.
            flow_bwd: optical flow current -> previous, (B, 2, H, W), defined
                on the current frame's grid.
        """
        if not self.has_state:
            return
        cfg = self.config
        # 1) The state lives on the previous grid. Backward warping needs a
        #    flow defined on the *target* (current) grid pointing into the
        #    source (previous) frame, i.e. the current -> previous flow.
        self.mu = _backward_warp(self.mu, flow_bwd)
        self.P = _backward_warp(self.P, flow_bwd)
        # 2) Forward-backward error evaluated on the current grid, so it is
        #    aligned with the warped state.
        eps = forward_backward_error(flow_bwd, flow_fwd)
        Q = cfg.Q_base + cfg.alpha * (eps ** 2)
        self.P = self.P + Q
        # 3) Occlusion: reset variance to P_max (mu effectively floats free).
        occ = eps > cfg.occ_threshold
        self.P = torch.where(occ, torch.full_like(self.P, cfg.P_max), self.P)

    def update(
        self,
        sparse_depth: torch.Tensor,
        sparse_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Scalar Kalman measurement update at observed pixels.

        Args:
            sparse_depth: measurements, broadcast-compatible with the state.
                Values where ``sparse_mask`` is zero are ignored, so they may
                be arbitrary placeholders (including NaN/inf, e.g. log(0)).
            sparse_mask: observation mask (bool or numeric), broadcast-
                compatible with the state; zero means "not observed". Non-zero
                values scale the Kalman gain.

        Returns:
            The updated ``(mu, P)``.
        """
        cfg = self.config
        m = sparse_mask.float()
        K = self.P / (self.P + cfg.R)
        # Zero the innovation where unobserved so non-finite placeholders
        # cannot leak into the state through 0 * inf = nan.
        diff = sparse_depth - self.mu
        innovation = torch.where(m != 0, diff, torch.zeros_like(diff))
        self.mu = self.mu + m * K * innovation
        self.P = self.P * (1.0 - m * K)
        self.has_state = True
        return self.mu, self.P

    def absorb_measurement(
        self, mu_meas: torch.Tensor, P_meas: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Absorb a *dense* measurement (e.g. the within-denoising posterior)
        into the temporal state. Useful for chaining a Kalman-in-loop result
        into the next frame.

        Before the first measurement, the state is simply replaced by copies of
        ``mu_meas`` / ``P_meas``; afterwards a per-pixel Kalman fusion is done.

        Returns:
            The updated ``(mu, P)``.
        """
        if not self.has_state:
            self.mu = mu_meas.clone()
            self.P = P_meas.clone()
            self.has_state = True
            return self.mu, self.P
        K = self.P / (self.P + P_meas)
        self.mu = self.mu + K * (mu_meas - self.mu)
        self.P = (1.0 - K) * self.P
        return self.mu, self.P

    def metric_uncertainty(self) -> torch.Tensor:
        """Std-dev in log-depth → (multiplicative) metric uncertainty.

        Following paper §3.4: σ_metric ≈ exp(sqrt(P)) - 1. ``expm1`` is used
        because it stays accurate when sqrt(P) is small, where exp(·) - 1
        suffers catastrophic cancellation. Negative variances are clamped to 0.
        """
        return torch.expm1(self.P.clamp_min(0.0).sqrt())