File size: 4,061 Bytes
436b829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
Kalman-in-the-loop denoising sampler — Algorithm 1 in paper.

Given a DiT that predicts x_0 (or velocity convertible to x_0) from x_τ at each
denoising step, run a per-pixel Kalman state estimator across sampling steps.
The state accumulates evidence: early steps trust the temporal prior (large σ²
≈ large measurement variance), late steps shift toward the DiT prediction.

Variance is monotonically decreasing through denoising. The final posterior
variance is returned alongside the depth — `temporal_kalman.py` consumes it as
the next frame's prior.
"""
from __future__ import annotations

from dataclasses import dataclass

import torch

from ppd.utils.diffusion.sampler import EulerSampler
from ppd.lpd.posterior_projection import posterior_project


@dataclass
class KalmanInLoopConfig:
    """Tunable constants for the Kalman-in-the-loop denoising sampler.

    Attributes
    ----------
    R_proj : float
        LiDAR variance used by the projection step (paper §4.4).
    proj_alpha : float
        Step-size scale α applied in the posterior projection.
    measurement_floor : float
        Floor for the per-step measurement noise. In `kalman_in_loop_sample`
        it lower-bounds the normalized timestep before that value is squared
        into the measurement variance.
    init_P : float
        Prior variance assumed when no temporal information is supplied.
    """

    R_proj: float = 0.1
    proj_alpha: float = 0.1
    measurement_floor: float = 1e-3
    init_P: float = 1.0


@torch.no_grad()
def kalman_in_loop_sample(
    *,
    dit_predict_x0,
    sampler: EulerSampler,
    timesteps: torch.Tensor,
    x_T: torch.Tensor,
    cond: torch.Tensor,
    semantics_fn,
    sparse_depth: torch.Tensor,
    sparse_mask: torch.Tensor,
    mu_temporal: torch.Tensor | None = None,
    P_temporal: torch.Tensor | None = None,
    config: KalmanInLoopConfig | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run sampling with Kalman-in-the-loop denoising.

    At every timestep the DiT's x_0 prediction is fused into a per-pixel
    Kalman state (mean ``mu``, variance ``P``), the latent is advanced with
    one sampler step, and the result is passed through ``posterior_project``
    together with the current Kalman state.

    Parameters
    ----------
    dit_predict_x0 : callable
        Closure that takes (x_tau, tau) and returns predicted clean x_0
        (already in [-0.5,0.5] normalized space).
    sampler        : an EulerSampler instance for the diffusion ODE step.
                     ``sampler.schedule.T`` is used to normalize timesteps.
    timesteps      : timesteps to iterate over (descending — same as PPD).
                     Each element must be a single-element tensor: it is
                     converted with ``.float()`` and reshaped to (1,1,1,1).
    x_T            : initial noise (B,1,H,W).
    cond, semantics_fn:
        Conditioning accepted for interface compatibility; not used directly
        in this routine (the closure handles them).
    sparse_depth, sparse_mask:
        Sparse measurements forwarded to ``posterior_project`` at every step.
    mu_temporal, P_temporal:
        Optional temporal Kalman prior carried over from the previous frame.
        When omitted, the mean starts at zero and the variance at
        ``config.init_P``.
    config         : sampler constants. ``None`` (the default) creates a
                     fresh ``KalmanInLoopConfig()`` for this call, so no
                     config instance is shared between calls.

    Returns
    -------
    (x, P) : tuple of tensors
        ``x`` is the latent after the final step and projection; ``P`` is the
        posterior variance of the Kalman state after the final update.
        If ``timesteps`` is empty, ``x_T`` and the initial variance are
        returned unchanged.
    """
    # A default built in the signature would be a single shared (mutable)
    # instance; build it per call instead.
    if config is None:
        config = KalmanInLoopConfig()

    B, _, H, W = x_T.shape
    device = x_T.device
    T = sampler.schedule.T

    # initialize Kalman state
    if mu_temporal is None:
        mu = torch.zeros(B, 1, H, W, device=device)
    else:
        mu = mu_temporal.clone()
    if P_temporal is None:
        P = torch.full((B, 1, H, W), config.init_P, device=device)
    else:
        P = P_temporal.clone()

    x = x_T
    for tau in timesteps:
        # Normalized time t/T, computed once and reused below.
        t_frac = tau.float() / T

        # NOTE(review): the floor is applied to sigma, so the effective floor
        # on the measurement *variance* is measurement_floor ** 2 — confirm
        # this matches the intended meaning of the config field.
        sigma_t = t_frac.clamp_min(config.measurement_floor)
        x_hat0 = dit_predict_x0(x, tau)

        # Kalman update treating DiT prediction as a measurement with σ_τ²
        meas_var = sigma_t ** 2
        K = P / (P + meas_var)
        mu = mu + K * (x_hat0 - mu)
        P = (1.0 - K) * P

        # Standard diffusion step. The sampler expects the model's "pred";
        # for v-prediction this is v = x_T - x_0, which we synthesize from the
        # predicted x_0 and the current latent:
        # x_t = (1-t/T)*x_0 + (t/T)*x_T  ->  x_T = (x_t - (1-t/T) x_0) * T/t
        # The reshaped tensor is no longer 0-dim, so it must live on the same
        # device as the latent to be combined with it.
        t_norm = t_frac.to(device).view(1, 1, 1, 1)
        denom = t_norm.clamp_min(1e-6)  # guard the division as t -> 0
        x_T_implied = (x - (1 - t_norm) * x_hat0) / denom
        v_pred = x_T_implied - x_hat0
        x = sampler.step(pred=v_pred, x_t=x, t=tau)

        # Posterior projection (LiDAR + Kalman) on the post-step latent
        x = posterior_project(
            x,
            sigma_t,
            sparse_depth=sparse_depth,
            sparse_mask=sparse_mask,
            R=config.R_proj,
            mu_prior=mu,
            P_prior=P,
            alpha=config.proj_alpha,
        )

    return x, P