File size: 5,101 Bytes
436b829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
Video inference wrapper.

Runs LPD frame-by-frame, threading a `TemporalKalmanFilter` between frames.
Optical flow is recomputed per-frame with RAFT (torchvision) — the variance
inflation in the predict step is driven by forward-backward consistency, so
flow noise self-regulates. The Kalman state's posterior variance is fed
back as a per-token confidence to the prompt encoder via
`uncertainty_modulation.modulate_density`.

This module is inference-only (paper §3.6 — temporal mechanisms require no
extra training).
"""
from __future__ import annotations

from dataclasses import dataclass, field
from functools import partial
from typing import Iterable

import torch
import torch.nn.functional as F

from ppd.lpd.lpd_train import LiDARPerfectDepth
from ppd.lpd.temporal_kalman import TemporalKalmanFilter, TemporalKalmanConfig


@dataclass
class VideoInferenceConfig:
    """Settings for video inference with temporal Kalman fusion."""

    # RAFT variant name handed to `_load_raft`, or "none" to disable optical flow.
    flow_model: str = "raft_large"
    # RAFT refinement iterations per flow computation.
    raft_iters: int = 12
    # Downsample so the short side is this size for RAFT, then upsample the flow;
    # None disables downsampling.
    flow_resize: int | None = 384
    # Built per instance via default_factory: a shared instance default would be
    # aliased across configs (and is rejected by dataclasses on Python 3.11+ when
    # the default is unhashable).
    temporal: TemporalKalmanConfig = field(default_factory=TemporalKalmanConfig)


def _load_raft(model_name: str, device: torch.device):
    """Lazy-load a torchvision RAFT model and return a flow predictor.

    Args:
        model_name: RAFT variant, ``"raft_large"`` or ``"raft_small"``.
        device: Device the model is moved to.

    Returns:
        A callable ``predict(img1, img2, iters=12) -> flow``. ``img1`` and
        ``img2`` are ``(B, 3, H, W)`` images in ``[0, 1]``; the result is
        RAFT's final (most refined) flow estimate.

    Raises:
        ValueError: If ``model_name`` is not a supported RAFT variant.
    """
    # Validate before importing torchvision so a typo fails fast instead of
    # silently loading a different model than the one requested.
    if model_name not in ("raft_large", "raft_small"):
        raise ValueError(
            f"unsupported flow_model {model_name!r}; expected 'raft_large' or 'raft_small'"
        )

    # Deferred import: torchvision is only needed when optical flow is enabled.
    from torchvision.models import optical_flow

    if model_name == "raft_small":
        weights = optical_flow.Raft_Small_Weights.DEFAULT
        model = optical_flow.raft_small(weights=weights)
    else:
        weights = optical_flow.Raft_Large_Weights.DEFAULT
        model = optical_flow.raft_large(weights=weights)
    model = model.to(device).eval()
    # Preset preprocessing matching the chosen weights; applied to both images jointly.
    transforms = weights.transforms()

    def predict(img1: torch.Tensor, img2: torch.Tensor, iters: int = 12) -> torch.Tensor:
        # img1, img2: (B, 3, H, W) in [0, 1]
        img1, img2 = transforms(img1, img2)
        with torch.no_grad():
            # RAFT returns one flow per refinement iteration; keep the last one.
            flow = model(img1, img2, num_flow_updates=iters)[-1]
        return flow

    return predict


def _compute_flow(
    predict_fn,
    img_prev: torch.Tensor,
    img_curr: torch.Tensor,
    resize: int | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute forward and backward optical flow between two frames.

    The solver always runs at a working resolution whose sides are multiples
    of 8, which RAFT requires. When ``resize`` is set and the short side
    exceeds it, the working resolution is also downscaled so the short side
    is about ``resize``. Flow computed at a working resolution different from
    the input is resized back to ``(H, W)``, and its displacement vectors are
    rescaled to full-resolution pixel units.

    Args:
        predict_fn: Callable ``(img1, img2) -> flow`` that returns
            ``(B, 2, h, w)`` flow for ``(B, C, h, w)`` inputs (e.g. the
            predictor from ``_load_raft``).
        img_prev: Previous frame, ``(B, C, H, W)``.
        img_curr: Current frame, same shape as ``img_prev``.
        resize: Target short side for the solver, or ``None`` to only snap
            the resolution to multiples of 8.

    Returns:
        ``(f_fwd, f_bwd)``: prev->curr and curr->prev flow, each ``(B, 2, H, W)``.
    """
    H, W = img_prev.shape[-2:]
    downscale = resize is not None and min(H, W) > resize
    scale = resize / min(H, W) if downscale else 1.0
    # Snap to multiples of 8 even without downscaling (RAFT rejects other
    # sizes); clamp to >= 8 so tiny inputs never round down to zero.
    new_h = max(8, int(round(H * scale / 8)) * 8)
    new_w = max(8, int(round(W * scale / 8)) * 8)

    if (new_h, new_w) == (H, W):
        f_fwd = predict_fn(img_prev, img_curr)
        f_bwd = predict_fn(img_curr, img_prev)
        return f_fwd, f_bwd

    def _to_full_res(flow: torch.Tensor) -> torch.Tensor:
        # Resize the field, then rescale displacements: channel 0 is motion
        # along width (x), channel 1 along height (y).
        flow = F.interpolate(flow, size=(H, W), mode="bilinear", align_corners=False)
        flow[:, 0] *= W / new_w
        flow[:, 1] *= H / new_h
        return flow

    ip = F.interpolate(img_prev, size=(new_h, new_w), mode="bilinear", align_corners=False)
    ic = F.interpolate(img_curr, size=(new_h, new_w), mode="bilinear", align_corners=False)
    f_fwd = _to_full_res(predict_fn(ip, ic))
    f_bwd = _to_full_res(predict_fn(ic, ip))
    return f_fwd, f_bwd


@torch.no_grad()
def run_video(
    pipeline: LiDARPerfectDepth,
    frames: Iterable[dict],
    *,
    config: VideoInferenceConfig | None = None,
) -> list[dict]:
    """Run LPD over a sequence of per-frame batches with temporal Kalman fusion.

    For each frame:

    1. If optical flow is enabled and a previous frame exists, compute
       forward/backward flow and run the filter's predict step.
    2. Pass the current filter state (when it has one) to
       ``pipeline.forward_test`` as ``kalman_mu_prior`` / ``kalman_P_prior``.
    3. Absorb the pipeline's output back into the filter.

    Args:
        pipeline: Model whose parameters determine the inference device.
        frames: Iterable of dicts with at least ``image``, a (B, 3, H, W)
            batch (typically B=1). Optionally includes ``sparse_depth`` +
            ``sparse_mask``, or ``depth`` + ``mask`` (for simulating sparse
            input). Batch size and resolution must be constant across the
            sequence.
        config: Inference settings; ``None`` uses ``VideoInferenceConfig()``.

    Returns:
        One dict per frame as returned by ``pipeline.forward_test``, which must
        contain ``depth`` and ``kalman_variance``. Each dict gains
        ``kalman_variance_running``, a copy of the filter variance after
        absorbing that frame.

    Raises:
        ValueError: If a frame's batch size or resolution differs from the
            first frame's.
    """
    if config is None:
        # Fresh per call rather than one instance shared via the signature default.
        config = VideoInferenceConfig()
    device = next(pipeline.parameters()).device

    flow_predict = None
    if config.flow_model != "none":
        # Bind the configured refinement count; otherwise RAFT would always run
        # with the predictor's default and `raft_iters` would have no effect.
        flow_predict = partial(_load_raft(config.flow_model, device), iters=config.raft_iters)

    kf: TemporalKalmanFilter | None = None
    kf_shape: tuple[int, int, int, int] | None = None
    prev_image: torch.Tensor | None = None
    out_frames: list[dict] = []

    for frame in frames:
        img = frame["image"].to(device)
        state_shape = (img.shape[0], 1, img.shape[-2], img.shape[-1])
        if kf is None:
            # Lazily build the Kalman filter once we know the batch size / resolution.
            kf = TemporalKalmanFilter(
                shape=state_shape,
                device=device,
                config=config.temporal,
            )
            kf_shape = state_shape
        elif state_shape != kf_shape:
            raise ValueError(
                f"run_video: frame state shape {state_shape} does not match the first "
                f"frame's {kf_shape}; batch size and resolution must stay constant"
            )

        # Predict: warp running state by optical flow and inflate variance.
        if prev_image is not None and flow_predict is not None:
            f_fwd, f_bwd = _compute_flow(flow_predict, prev_image, img, config.flow_resize)
            kf.predict(flow_fwd=f_fwd, flow_bwd=f_bwd)

        # Shallow copy so the caller's frame dict is not mutated.
        frame_with_prior = dict(frame)
        if kf.has_state:
            frame_with_prior["kalman_mu_prior"] = kf.mu
            frame_with_prior["kalman_P_prior"] = kf.P
        out = pipeline.forward_test(frame_with_prior)

        # Absorb the per-frame posterior into the temporal state so the next
        # frame gets a tighter prior than sparse LiDAR alone would give.
        kf.absorb_measurement(
            # NOTE(review): assumes forward_test returns depth offset by +0.5 from
            # the filter's normalized space — confirm against forward_test.
            mu_meas=out["depth"] - 0.5,
            P_meas=out["kalman_variance"],
        )
        out["kalman_variance_running"] = kf.P.clone()
        out_frames.append(out)
        prev_image = img

    return out_frames