"""
Video inference wrapper.

Runs LPD frame-by-frame, threading a `TemporalKalmanFilter` between frames.
Optical flow is recomputed per frame with RAFT (torchvision) — the variance
inflation in the predict step is driven by forward-backward consistency, so
flow noise self-regulates. The Kalman state's posterior variance is fed back
as a per-token confidence to the prompt encoder via
`uncertainty_modulation.modulate_density`.

This module is inference-only (paper §3.6 — temporal mechanisms require no
extra training).
"""

from __future__ import annotations

from dataclasses import dataclass, field
from functools import partial
from typing import Callable, Iterable

import torch
import torch.nn.functional as F

from ppd.lpd.lpd_train import LiDARPerfectDepth
from ppd.lpd.temporal_kalman import TemporalKalmanFilter, TemporalKalmanConfig

# torchvision's RAFT rejects inputs whose height or width is not divisible by 8.
_RAFT_STRIDE = 8


@dataclass
class VideoInferenceConfig:
    """Options controlling `run_video`."""

    # Optical-flow backbone: "raft_large" / "raft_small" (torchvision), or
    # "none" to skip the flow-driven Kalman predict step entirely.
    flow_model: str = "raft_large"
    # RAFT recurrent refinement iterations per flow estimate.
    raft_iters: int = 12
    # Frames whose short side exceeds this are downsampled to roughly this
    # short side for RAFT; the flow is then upsampled back. None disables it.
    flow_resize: int | None = 384
    # Kalman filter settings. `default_factory` gives each config its own
    # instance; a shared instance default is also rejected by `dataclasses`
    # on Python 3.11+ when the type is unhashable (e.g. a non-frozen dataclass).
    temporal: TemporalKalmanConfig = field(default_factory=TemporalKalmanConfig)


def _load_raft(model_name: str, device: torch.device) -> Callable[..., torch.Tensor]:
    """Load a torchvision RAFT model and return a flow predictor.

    The torchvision import is deferred to call time, so `flow_model="none"`
    never touches torchvision's optical-flow module.

    Args:
        model_name: "raft_large" or "raft_small".
        device: device the model is moved to.

    Returns:
        ``predict(img1, img2, iters=12) -> flow``. It takes two (B, 3, H, W)
        image batches (expected in [0, 1], with H and W divisible by 8). It
        returns the final RAFT flow estimate from ``img1`` to ``img2``.

    Raises:
        ValueError: if ``model_name`` is not a supported RAFT variant.
    """
    from torchvision.models.optical_flow import (
        Raft_Large_Weights,
        Raft_Small_Weights,
        raft_large,
        raft_small,
    )

    builders = {
        "raft_large": (raft_large, Raft_Large_Weights.DEFAULT),
        "raft_small": (raft_small, Raft_Small_Weights.DEFAULT),
    }
    if model_name not in builders:
        raise ValueError(
            f"Unsupported flow_model {model_name!r}; expected one of "
            f"{sorted(builders)} or 'none'."
        )
    build, weights = builders[model_name]
    model = build(weights=weights).to(device).eval()
    transforms = weights.transforms()

    def predict(img1: torch.Tensor, img2: torch.Tensor, iters: int = 12) -> torch.Tensor:
        # img1, img2: (B, 3, H, W) in [0, 1]; the weights' preset transforms
        # prepare the pair for RAFT.
        img1, img2 = transforms(img1, img2)
        with torch.no_grad():
            # RAFT returns one flow estimate per refinement step; keep the last.
            flow = model(img1, img2, num_flow_updates=iters)[-1]
        return flow

    return predict


def _resize_flow(flow: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
    """Bilinearly resize a (B, 2, h, w) flow field to ``size`` and rescale it.

    Channel 0 (horizontal displacement) is scaled by the width ratio and
    channel 1 (vertical displacement) by the height ratio, so vectors are
    expressed in pixels of the output resolution.
    """
    h, w = flow.shape[-2:]
    out_h, out_w = size
    out = F.interpolate(flow, size=(out_h, out_w), mode="bilinear", align_corners=False)
    out[:, 0] *= out_w / w
    out[:, 1] *= out_h / h
    return out


def _compute_flow(
    predict_fn,
    img_prev: torch.Tensor,
    img_curr: torch.Tensor,
    resize: int | None,
    iters: int | None = None,
):
    """Compute forward (prev -> curr) and backward (curr -> prev) flow.

    If the short side exceeds ``resize``, RAFT runs on a downsampled copy.
    The solve resolution is always snapped to multiples of 8, as RAFT
    requires. Flow is returned at the input resolution, in input-pixel units.

    Args:
        predict_fn: callable ``(img1, img2[, iters=...]) -> flow``, e.g. the
            predictor returned by `_load_raft`.
        img_prev, img_curr: (B, 3, H, W) image batches.
        resize: target short side for the solve, or None for native size.
        iters: RAFT refinement iterations; None keeps ``predict_fn``'s default.

    Returns:
        ``(f_fwd, f_bwd)``, each (B, 2, H, W).
    """
    H, W = img_prev.shape[-2:]
    scale = resize / min(H, W) if resize is not None and min(H, W) > resize else 1.0
    # Snapping also covers native-resolution frames whose H or W is not a
    # multiple of 8, which RAFT would otherwise reject.
    new_h = max(_RAFT_STRIDE, round(H * scale / _RAFT_STRIDE) * _RAFT_STRIDE)
    new_w = max(_RAFT_STRIDE, round(W * scale / _RAFT_STRIDE) * _RAFT_STRIDE)
    solve = predict_fn if iters is None else partial(predict_fn, iters=iters)

    if (new_h, new_w) == (H, W):
        return solve(img_prev, img_curr), solve(img_curr, img_prev)

    ip = F.interpolate(img_prev, size=(new_h, new_w), mode="bilinear", align_corners=False)
    ic = F.interpolate(img_curr, size=(new_h, new_w), mode="bilinear", align_corners=False)
    f_fwd = _resize_flow(solve(ip, ic), (H, W))
    f_bwd = _resize_flow(solve(ic, ip), (H, W))
    return f_fwd, f_bwd


@torch.no_grad()
def run_video(
    pipeline: LiDARPerfectDepth,
    frames: Iterable[dict],
    *,
    config: VideoInferenceConfig | None = None,
) -> list[dict]:
    """Run LPD over a sequence of per-frame batches.

    Args:
        pipeline: model whose ``forward_test(frame)`` is run on every frame.
            The device of its parameters is used for all per-frame work.
        frames: yields dicts with at least `image`, optionally `sparse_depth` +
            `sparse_mask`, optionally `depth` + `mask` (for simulating sparse).
            `image` is a (B, 3, H, W) batch — typically B=1. All frames must
            share the first frame's batch size and resolution.
        config: inference options; a fresh `VideoInferenceConfig()` if None.

    Returns:
        One dict per frame: the output of ``pipeline.forward_test``, which
        must contain `depth` and `kalman_variance`. Each dict also gets
        `kalman_variance_running`, a copy of the filter's variance after
        absorbing that frame.

    Raises:
        ValueError: if a frame's batch size or resolution differs from the
            first frame's, or if ``config.flow_model`` is unsupported.
    """
    if config is None:
        config = VideoInferenceConfig()
    device = next(pipeline.parameters()).device
    flow_predict = None if config.flow_model == "none" else _load_raft(config.flow_model, device)

    kf: TemporalKalmanFilter | None = None
    state_shape: tuple[int, int, int, int] | None = None
    prev_image: torch.Tensor | None = None
    out_frames: list[dict] = []
    for frame in frames:
        img = frame["image"].to(device)
        shape = (img.shape[0], 1, img.shape[-2], img.shape[-1])

        # Lazily build the Kalman filter once the batch size / resolution is
        # known; the state cannot follow a shape change mid-sequence.
        if kf is None:
            kf = TemporalKalmanFilter(shape=shape, device=device, config=config.temporal)
            state_shape = shape
        elif shape != state_shape:
            raise ValueError(
                f"Frame batch/resolution changed mid-sequence: got {shape}, "
                f"Kalman state was built for {state_shape}."
            )

        # Predict: warp running state by optical flow and inflate variance.
        # Skipped on the first frame and when flow_model == "none".
        if prev_image is not None and flow_predict is not None:
            f_fwd, f_bwd = _compute_flow(
                flow_predict, prev_image, img, config.flow_resize, iters=config.raft_iters
            )
            kf.predict(flow_fwd=f_fwd, flow_bwd=f_bwd)

        # Pass the current Kalman prior into the per-frame inference without
        # mutating the caller's dict.
        frame_with_prior = dict(frame)
        if kf.has_state:
            frame_with_prior["kalman_mu_prior"] = kf.mu
            frame_with_prior["kalman_P_prior"] = kf.P
        out = pipeline.forward_test(frame_with_prior)

        # Absorb the per-frame posterior into the temporal state — gives the
        # next frame a tighter prior than just sparse-LiDAR alone.
        # NOTE(review): the -0.5 shifts `depth` "back to normalized space";
        # this assumes forward_test's depth is offset by +0.5 from the
        # filter's space — confirm against LiDARPerfectDepth.forward_test.
        kf.absorb_measurement(
            mu_meas=out["depth"] - 0.5,
            P_meas=out["kalman_variance"],
        )
        out["kalman_variance_running"] = kf.P.clone()
        out_frames.append(out)
        prev_image = img
    return out_frames