# LiDAR-Perfect-Depth / code/ppd/lpd/lpd_video.py
# chenming-wu · code · 436b829 (verified)
"""
Video inference wrapper.
Runs LPD frame-by-frame, threading a `TemporalKalmanFilter` between frames.
Optical flow is recomputed per-frame with RAFT (torchvision) — the variance
inflation in the predict step is driven by forward-backward consistency, so
flow noise self-regulates. The Kalman state's posterior variance is fed
back as a per-token confidence to the prompt encoder via
`uncertainty_modulation.modulate_density`.
This module is inference-only (paper §3.6 — temporal mechanisms require no
extra training).
"""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Iterable

import torch
import torch.nn.functional as F

from ppd.lpd.lpd_train import LiDARPerfectDepth
from ppd.lpd.temporal_kalman import TemporalKalmanFilter, TemporalKalmanConfig
@dataclass
class VideoInferenceConfig:
    """Settings for video inference with temporal Kalman fusion."""

    # Name of the optical-flow model to load, or "none" to disable flow-based
    # warping of the Kalman state between frames.
    flow_model: str = "raft_large"
    # RAFT refinement iterations per flow estimate.
    raft_iters: int = 12
    # Short side the flow solver runs at (downsample only, flow is upsampled
    # back afterwards); None keeps the native resolution.
    flow_resize: int | None = 384
    # Temporal Kalman filter settings. default_factory gives every config its
    # own instance instead of one object shared by all configs; a shared
    # instance default is also rejected by dataclasses on Python 3.11+ when
    # TemporalKalmanConfig is unhashable (e.g. a plain non-frozen dataclass).
    temporal: TemporalKalmanConfig = field(default_factory=TemporalKalmanConfig)
def _load_raft(model_name: str, device: torch.device):
"""Lazy-load torchvision RAFT and return a callable (img1, img2) -> flow."""
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
weights = Raft_Large_Weights.DEFAULT
model = raft_large(weights=weights).to(device).eval()
transforms = weights.transforms()
def predict(img1: torch.Tensor, img2: torch.Tensor, iters: int = 12) -> torch.Tensor:
# img1, img2: (B, 3, H, W) in [0, 1]
img1, img2 = transforms(img1, img2)
with torch.no_grad():
flow = model(img1, img2, num_flow_updates=iters)[-1]
return flow
return predict
def _compute_flow(predict_fn, img_prev: torch.Tensor, img_curr: torch.Tensor, resize: int | None):
"""Compute forward + backward flow with optional downsampled solver."""
H, W = img_prev.shape[-2:]
if resize is not None and min(H, W) > resize:
scale = resize / min(H, W)
new_h = int(round(H * scale / 8) * 8)
new_w = int(round(W * scale / 8) * 8)
ip = F.interpolate(img_prev, size=(new_h, new_w), mode="bilinear", align_corners=False)
ic = F.interpolate(img_curr, size=(new_h, new_w), mode="bilinear", align_corners=False)
f_fwd = predict_fn(ip, ic)
f_bwd = predict_fn(ic, ip)
# rescale flow to original resolution
f_fwd = F.interpolate(f_fwd, size=(H, W), mode="bilinear", align_corners=False)
f_bwd = F.interpolate(f_bwd, size=(H, W), mode="bilinear", align_corners=False)
f_fwd[:, 0] *= W / new_w
f_fwd[:, 1] *= H / new_h
f_bwd[:, 0] *= W / new_w
f_bwd[:, 1] *= H / new_h
else:
f_fwd = predict_fn(img_prev, img_curr)
f_bwd = predict_fn(img_curr, img_prev)
return f_fwd, f_bwd
@torch.no_grad()
def run_video(
    pipeline: LiDARPerfectDepth,
    frames: Iterable[dict],
    *,
    config: VideoInferenceConfig | None = None,
) -> list[dict]:
    """Run LPD over a sequence of per-frame batches with temporal Kalman fusion.

    For each frame:

    1. If a previous frame exists and flow is enabled, forward/backward optical
       flow between the two frames drives the Kalman ``predict`` step.
    2. Once the filter has state, its prior is attached to a shallow copy of
       the frame as ``kalman_mu_prior`` / ``kalman_P_prior``. The frame is then
       run through ``pipeline.forward_test``.
    3. The output's ``depth`` and ``kalman_variance`` are absorbed back into the
       filter as a measurement.

    Args:
        pipeline: Model whose parameters determine the device everything runs on.
        frames: Iterable of dicts with at least ``image``. They may also carry
            ``sparse_depth`` + ``sparse_mask``, and ``depth`` + ``mask`` (for
            simulating sparse input). Each dict is a (B, 3, H, W) batch,
            typically B=1. B, H and W must stay constant across the sequence.
        config: Inference settings. A fresh ``VideoInferenceConfig()`` is built
            per call when omitted.

    Returns:
        One ``forward_test`` output dict per frame. Each is augmented with
        ``kalman_variance_running``, a copy of the filter variance after that
        frame's measurement was absorbed.

    Raises:
        ValueError: if a frame's batch size or resolution differs from the
            first frame's.
    """
    if config is None:
        # Built per call rather than as a def-time default, so no config
        # instance is shared between calls.
        config = VideoInferenceConfig()
    device = next(pipeline.parameters()).device

    flow_predict = None
    if config.flow_model != "none":
        raft = _load_raft(config.flow_model, device)

        def _flow(img1, img2):
            # Forward the configured iteration count; _compute_flow only
            # passes (img1, img2) to its predictor.
            return raft(img1, img2, iters=config.raft_iters)

        flow_predict = _flow

    kf: TemporalKalmanFilter | None = None
    state_shape: tuple[int, int, int, int] | None = None
    prev_image: torch.Tensor | None = None
    out_frames: list[dict] = []
    for frame in frames:
        img = frame["image"].to(device)
        shape = (img.shape[0], 1, img.shape[-2], img.shape[-1])

        if kf is None:
            # Built lazily: the state resolution is only known from the first frame.
            kf = TemporalKalmanFilter(shape=shape, device=device, config=config.temporal)
            state_shape = shape
        elif shape != state_shape:
            raise ValueError(
                f"Frame batch size/resolution changed mid-sequence: expected state shape "
                f"{state_shape}, got {shape}. Start a new run_video call for the new stream."
            )

        # Predict: warp the running state by optical flow and inflate its variance.
        if prev_image is not None and flow_predict is not None:
            f_fwd, f_bwd = _compute_flow(flow_predict, prev_image, img, config.flow_resize)
            kf.predict(flow_fwd=f_fwd, flow_bwd=f_bwd)

        # Shallow copy so the caller's frame dict is not mutated.
        frame_with_prior = dict(frame)
        if kf.has_state:
            frame_with_prior["kalman_mu_prior"] = kf.mu
            frame_with_prior["kalman_P_prior"] = kf.P
        out = pipeline.forward_test(frame_with_prior)

        # Absorb the per-frame posterior into the temporal state, so the next
        # frame gets a tighter prior than from sparse LiDAR alone.
        kf.absorb_measurement(
            # NOTE(review): assumes forward_test's `depth` is the filter's
            # normalized depth shifted by +0.5; confirm against forward_test.
            mu_meas=out["depth"] - 0.5,
            P_meas=out["kalman_variance"],
        )
        out["kalman_variance_running"] = kf.P.clone()
        out_frames.append(out)
        prev_image = img
    return out_frames