"""
Video inference wrapper.
Runs LPD frame-by-frame, threading a `TemporalKalmanFilter` between frames.
Optical flow is recomputed per-frame with RAFT (torchvision) — the variance
inflation in the predict step is driven by forward-backward consistency, so
flow noise self-regulates. The Kalman state's posterior variance is fed
back as a per-token confidence to the prompt encoder via
`uncertainty_modulation.modulate_density`.
This module is inference-only (paper §3.6 — temporal mechanisms require no
extra training).
"""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Iterable

import torch
import torch.nn.functional as F

from ppd.lpd.lpd_train import LiDARPerfectDepth
from ppd.lpd.temporal_kalman import TemporalKalmanFilter, TemporalKalmanConfig
@dataclass
class VideoInferenceConfig:
    """Options for running the depth pipeline over a video sequence."""

    # Optical-flow backend name handed to `_load_raft`; the literal "none"
    # disables flow (and therefore the Kalman predict step) entirely.
    flow_model: str = "raft_large"
    # Requested number of RAFT flow-refinement iterations.
    raft_iters: int = 12
    # Short-side resolution the flow solver runs at; frames whose short side
    # is larger are downsampled first and the flow is upsampled afterwards.
    # `None` disables the downsampling.
    flow_resize: int | None = 384
    # Built per instance via `default_factory`: a class-instance default would
    # be shared by every config object, and Python 3.11+ rejects unhashable
    # instances as dataclass defaults with a ValueError.
    temporal: TemporalKalmanConfig = field(default_factory=TemporalKalmanConfig)
def _load_raft(model_name: str, device: torch.device):
    """Lazy-load a torchvision RAFT model and return a callable (img1, img2) -> flow.

    Args:
        model_name: Which torchvision RAFT variant to build, either
            ``"raft_large"`` or ``"raft_small"``.
        device: Device the model is moved to.

    Returns:
        ``predict(img1, img2, iters=12)``, which applies the weights' bundled
        preprocessing transforms and returns the last (most refined) flow
        estimate produced by the model.

    Raises:
        ValueError: If ``model_name`` is not a supported variant.
    """
    # Map the config string to (builder, weights-enum) attribute names inside
    # torchvision.models.optical_flow.
    variants = {
        "raft_large": ("raft_large", "Raft_Large_Weights"),
        "raft_small": ("raft_small", "Raft_Small_Weights"),
    }
    # Validate before importing torchvision so a typo fails fast instead of
    # silently loading a different model than the one requested.
    try:
        builder_name, weights_name = variants[model_name]
    except KeyError:
        raise ValueError(
            f"Unsupported flow model {model_name!r}; expected one of {sorted(variants)}"
        ) from None

    from torchvision.models import optical_flow

    weights = getattr(optical_flow, weights_name).DEFAULT
    model = getattr(optical_flow, builder_name)(weights=weights).to(device).eval()
    transforms = weights.transforms()

    def predict(img1: torch.Tensor, img2: torch.Tensor, iters: int = 12) -> torch.Tensor:
        # img1, img2: (B, 3, H, W) in [0, 1]
        img1, img2 = transforms(img1, img2)
        with torch.no_grad():
            # The model returns one flow estimate per refinement step; keep the last.
            return model(img1, img2, num_flow_updates=iters)[-1]

    return predict
def _compute_flow(predict_fn, img_prev: torch.Tensor, img_curr: torch.Tensor, resize: int | None):
    """Compute forward + backward flow with optional downsampled solver.

    The flow solver is always run at a resolution whose height and width are
    multiples of 8 (torchvision's RAFT rejects other sizes). If that differs
    from the input resolution, the frames are bilinearly resampled for the
    solver and the resulting flow is resampled back and rescaled so that
    displacements are expressed in input-resolution pixels.

    Args:
        predict_fn: Callable ``(img_a, img_b) -> flow`` returning a tensor
            with 2 channels (x displacement, then y displacement).
        img_prev: Previous frame; the last two dims are (H, W).
        img_curr: Current frame, same size as ``img_prev``.
        resize: Target short-side length for the solver. Downsampling only
            happens when the short side exceeds it; ``None`` disables it.

    Returns:
        ``(f_fwd, f_bwd)``: flow prev->curr and curr->prev at (H, W).
    """
    H, W = img_prev.shape[-2:]
    short_side = min(H, W)
    scale = resize / short_side if resize is not None and short_side > resize else 1.0

    # Snap the solver resolution to a multiple of 8, never below 8 so that a
    # tiny frame cannot round down to an empty image.
    new_h = max(8, int(round(H * scale / 8) * 8))
    new_w = max(8, int(round(W * scale / 8) * 8))

    if (new_h, new_w) == (H, W):
        # Already a valid solver resolution: no resampling needed.
        return predict_fn(img_prev, img_curr), predict_fn(img_curr, img_prev)

    ip = F.interpolate(img_prev, size=(new_h, new_w), mode="bilinear", align_corners=False)
    ic = F.interpolate(img_curr, size=(new_h, new_w), mode="bilinear", align_corners=False)

    # Flow vectors are measured in solver-resolution pixels, so after
    # upsampling the field each component must be scaled by the size ratio.
    scale_x = W / new_w
    scale_y = H / new_h
    flows = []
    for src, dst in ((ip, ic), (ic, ip)):
        flow = F.interpolate(predict_fn(src, dst), size=(H, W), mode="bilinear", align_corners=False)
        flow[:, 0] *= scale_x
        flow[:, 1] *= scale_y
        flows.append(flow)
    return flows[0], flows[1]
@torch.no_grad()
def run_video(
    pipeline: LiDARPerfectDepth,
    frames: Iterable[dict],
    *,
    config: VideoInferenceConfig | None = None,
) -> list[dict]:
    """Run LPD over a sequence of per-frame batches.

    `frames` yields dicts with at least `image`, optionally `sparse_depth` +
    `sparse_mask`, optionally `depth` + `mask` (for simulating sparse).
    Each dict is a (B,3,H,W) batch — typically B=1. The input dicts are not
    modified; a shallow copy carrying the Kalman prior is passed to the
    pipeline.

    Args:
        pipeline: Model exposing `parameters()` and `forward_test(frame)`;
            the latter must return a dict with `depth` and `kalman_variance`.
        frames: Iterable of per-frame batch dicts, in temporal order.
        config: Inference options; a fresh `VideoInferenceConfig()` is
            created when omitted.

    Returns:
        A list of per-frame outputs containing `depth` and the running
        `kalman_variance_running` map (a copy of the filter variance after
        the frame was absorbed).
    """
    # Build the default per call rather than in the signature, where a single
    # instance would be created at import time and shared by all callers.
    if config is None:
        config = VideoInferenceConfig()

    device = next(pipeline.parameters()).device

    flow_predict = None
    if config.flow_model != "none":
        raft = _load_raft(config.flow_model, device)

        def _flow_with_iters(img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
            # Forward the configured iteration count; without this the
            # predictor always falls back to its own default.
            return raft(img1, img2, iters=config.raft_iters)

        flow_predict = _flow_with_iters

    kf: TemporalKalmanFilter | None = None
    prev_image: torch.Tensor | None = None
    out_frames: list[dict] = []

    for frame in frames:
        img = frame["image"].to(device)

        # Lazily build the Kalman filter once we know the resolution.
        if kf is None:
            kf = TemporalKalmanFilter(
                shape=(img.shape[0], 1, img.shape[-2], img.shape[-1]),
                device=device,
                config=config.temporal,
            )

        # Predict: warp running state by optical flow and inflate variance.
        if prev_image is not None and flow_predict is not None:
            f_fwd, f_bwd = _compute_flow(flow_predict, prev_image, img, config.flow_resize)
            kf.predict(flow_fwd=f_fwd, flow_bwd=f_bwd)

        # Pass current Kalman prior into the per-frame inference.
        frame_with_prior = dict(frame)
        if kf.has_state:
            frame_with_prior["kalman_mu_prior"] = kf.mu
            frame_with_prior["kalman_P_prior"] = kf.P

        out = pipeline.forward_test(frame_with_prior)

        # Absorb the per-frame posterior into the temporal state — gives the
        # next frame a tighter prior than just sparse-LiDAR alone.
        kf.absorb_measurement(
            mu_meas=out["depth"] - 0.5,  # back to normalized space
            P_meas=out["kalman_variance"],
        )
        # Clone so later filter updates cannot alter this frame's snapshot.
        out["kalman_variance_running"] = kf.P.clone()
        out_frames.append(out)
        prev_image = img

    return out_frames
|