| """ |
| Video inference wrapper. |
| |
| Runs LPD frame-by-frame, threading a `TemporalKalmanFilter` between frames. |
| Optical flow is recomputed per-frame with RAFT (torchvision) — the variance |
| inflation in the predict step is driven by forward-backward consistency, so |
| flow noise self-regulates. The Kalman state's posterior variance is fed |
| back as a per-token confidence to the prompt encoder via |
| `uncertainty_modulation.modulate_density`. |
| |
| This module is inference-only (paper §3.6 — temporal mechanisms require no |
| extra training). |
| """ |
from __future__ import annotations

from dataclasses import dataclass, field
from functools import partial
from typing import Iterable

import torch
import torch.nn.functional as F

from ppd.lpd.lpd_train import LiDARPerfectDepth
from ppd.lpd.temporal_kalman import TemporalKalmanFilter, TemporalKalmanConfig
|
|
|
|
@dataclass
class VideoInferenceConfig:
    """Settings for the video inference loop (flow backend, flow resolution, Kalman filter)."""

    # Name of the torchvision RAFT variant to load, or "none" to skip optical
    # flow entirely (the Kalman predict step then never runs).
    flow_model: str = "raft_large"
    # Intended as RAFT's refinement-iteration count (num_flow_updates).
    raft_iters: int = 12
    # Flow is solved with the shorter image side downscaled to about this many
    # pixels; None disables downscaling.
    flow_resize: int | None = 384
    # Kalman filter settings. default_factory gives every config its own
    # instance. A shared instance default is also rejected by dataclasses on
    # Python 3.11+ whenever the default object is unhashable.
    temporal: TemporalKalmanConfig = field(default_factory=TemporalKalmanConfig)
|
|
|
|
| def _load_raft(model_name: str, device: torch.device): |
| """Lazy-load torchvision RAFT and return a callable (img1, img2) -> flow.""" |
| from torchvision.models.optical_flow import raft_large, Raft_Large_Weights |
| weights = Raft_Large_Weights.DEFAULT |
| model = raft_large(weights=weights).to(device).eval() |
| transforms = weights.transforms() |
|
|
| def predict(img1: torch.Tensor, img2: torch.Tensor, iters: int = 12) -> torch.Tensor: |
| |
| img1, img2 = transforms(img1, img2) |
| with torch.no_grad(): |
| flow = model(img1, img2, num_flow_updates=iters)[-1] |
| return flow |
|
|
| return predict |
|
|
|
|
| def _compute_flow(predict_fn, img_prev: torch.Tensor, img_curr: torch.Tensor, resize: int | None): |
| """Compute forward + backward flow with optional downsampled solver.""" |
| H, W = img_prev.shape[-2:] |
| if resize is not None and min(H, W) > resize: |
| scale = resize / min(H, W) |
| new_h = int(round(H * scale / 8) * 8) |
| new_w = int(round(W * scale / 8) * 8) |
| ip = F.interpolate(img_prev, size=(new_h, new_w), mode="bilinear", align_corners=False) |
| ic = F.interpolate(img_curr, size=(new_h, new_w), mode="bilinear", align_corners=False) |
| f_fwd = predict_fn(ip, ic) |
| f_bwd = predict_fn(ic, ip) |
| |
| f_fwd = F.interpolate(f_fwd, size=(H, W), mode="bilinear", align_corners=False) |
| f_bwd = F.interpolate(f_bwd, size=(H, W), mode="bilinear", align_corners=False) |
| f_fwd[:, 0] *= W / new_w |
| f_fwd[:, 1] *= H / new_h |
| f_bwd[:, 0] *= W / new_w |
| f_bwd[:, 1] *= H / new_h |
| else: |
| f_fwd = predict_fn(img_prev, img_curr) |
| f_bwd = predict_fn(img_curr, img_prev) |
| return f_fwd, f_bwd |
|
|
|
|
@torch.no_grad()
def run_video(
    pipeline: LiDARPerfectDepth,
    frames: Iterable[dict],
    *,
    config: VideoInferenceConfig | None = None,
) -> list[dict]:
    """Run LPD over a sequence of per-frame batches.

    `frames` yields dicts with at least `image`, optionally `sparse_depth` +
    `sparse_mask`, optionally `depth` + `mask` (for simulating sparse).
    Each dict is a (B,3,H,W) batch — typically B=1.

    Per-frame steps:

    1. Kalman predict (if flow is enabled and a previous frame exists): the
       state is propagated with forward/backward flow from the previous frame.
    2. If the filter holds a state, its mean and variance are passed to the
       pipeline as `kalman_mu_prior` / `kalman_P_prior`.
    3. The pipeline's `depth` / `kalman_variance` outputs update the filter.

    Args:
        pipeline: Model whose `forward_test` is run per frame. The device of its
            first parameter decides where images and the filter live.
        frames: Iterable of per-frame input dicts (see above). They are not
            mutated; a shallow copy is handed to the pipeline.
        config: Inference settings. ``None`` (default) uses a fresh
            ``VideoInferenceConfig()``. ``flow_model="none"`` disables optical
            flow, so the Kalman predict step never runs.

    Returns:
        One dict per frame, as returned by `pipeline.forward_test`, with
        `kalman_variance_running` (a copy of the filter's variance after the
        update) added.
    """
    # Build the default per call rather than sharing one instance created at import time.
    if config is None:
        config = VideoInferenceConfig()

    device = next(pipeline.parameters()).device

    flow_predict = None
    if config.flow_model != "none":
        # Bind the configured iteration count. Otherwise the predictor's own
        # default silently overrides `config.raft_iters`.
        flow_predict = partial(_load_raft(config.flow_model, device), iters=config.raft_iters)

    kf: TemporalKalmanFilter | None = None
    prev_image: torch.Tensor | None = None
    out_frames: list[dict] = []

    for frame in frames:
        img = frame["image"].to(device)

        # Lazily size the filter from the first frame: one channel at image resolution.
        if kf is None:
            kf = TemporalKalmanFilter(
                shape=(img.shape[0], 1, img.shape[-2], img.shape[-1]),
                device=device,
                config=config.temporal,
            )

        # Predict step: propagate the filter with prev<->curr optical flow.
        if prev_image is not None and flow_predict is not None:
            f_fwd, f_bwd = _compute_flow(flow_predict, prev_image, img, config.flow_resize)
            kf.predict(flow_fwd=f_fwd, flow_bwd=f_bwd)

        # Pass the current prior (if any) via a shallow copy so `frame` is untouched.
        frame_with_prior = dict(frame)
        if kf.has_state:
            frame_with_prior["kalman_mu_prior"] = kf.mu
            frame_with_prior["kalman_P_prior"] = kf.P
        out = pipeline.forward_test(frame_with_prior)

        # Update step: fuse this frame's prediction into the filter.
        # NOTE(review): the 0.5 offset is unexplained here. Presumably it maps
        # `depth` into the space the filter / `kalman_mu_prior` use — confirm.
        kf.absorb_measurement(
            mu_meas=out["depth"] - 0.5,
            P_meas=out["kalman_variance"],
        )
        out["kalman_variance_running"] = kf.P.clone()
        out_frames.append(out)
        prev_image = img

    return out_frames
|
|