# LiDAR-Perfect-Depth — ppd/lpd/uncertainty_modulation.py
"""
Uncertainty-guided prompt modulation — paper §3.5 (Eq. 7).
The Kalman variance feeds back into the prompt-encoder density signal:
ρ̃(p) = ρ(p) · (1 + P(p) / max_q P(q))
high-uncertainty regions amplify the prompt gate (so sparse LiDAR draws more
attention there); low-uncertainty regions trust the temporal prior. No new
parameters are introduced; ρ̃ replaces ρ inside `PromptGate`.
Variance is taken at the prompt-token grid (downsampled to match) so the
per-token modulation lines up with `SparsePromptEncoder` outputs.
"""
from __future__ import annotations
import torch
import torch.nn.functional as F
def _token_grid(T: int, h_full: int, w_full: int) -> tuple[int, int]:
    """Pick an exact factorization (h, w) of T closest to P_full's aspect ratio.

    Unlike a naive ``round(sqrt(...))`` guess, this always returns a pair with
    ``h * w == T`` (ultimate fallback is ``(1, T)``), so the pooled variance map
    flattens to exactly T tokens.
    """
    ratio = w_full / max(h_full, 1)
    # Ideal row count if T tiled the full map at its aspect ratio: T = h * (h*ratio).
    target_h = max(1.0, (T / max(ratio, 1e-6)) ** 0.5)
    best_h, best_w = 1, T  # always a valid factorization
    best_err = abs(1.0 - target_h)
    for d in range(2, int(T ** 0.5) + 1):
        if T % d:
            continue
        for h in (d, T // d):  # both members of the divisor pair
            err = abs(h - target_h)
            if err < best_err:
                best_err = err
                best_h, best_w = h, T // h
    return best_h, best_w


def modulate_density(
    density_tokens: torch.Tensor, P_full: torch.Tensor, *, eps: float = 1e-6
) -> torch.Tensor:
    """Apply Eq. 7: ρ̃(p) = ρ(p) · (1 + P(p) / max_q P(q)).

    Parameters
    ----------
    density_tokens : (B, T, 1) tensor — output of SparsePromptEncoder.
    P_full : (B, 1, H, W) tensor or None — current Kalman variance map. May be
        at full resolution or already downsampled to the token grid (H*W == T).
    eps : lower bound on the per-sample max, guarding division by zero when the
        variance map is all zeros.

    Returns
    -------
    ρ̃ with the same shape as ``density_tokens``. If ``P_full`` is None the
    input is returned unchanged (no modulation before the filter initializes).
    """
    if P_full is None:
        return density_tokens
    B, T, _ = density_tokens.shape
    if P_full.shape[-2] * P_full.shape[-1] == T:
        # Already at the token grid — just reshape to (B, T, 1).
        P = P_full.flatten(2).transpose(1, 2)
    else:
        # Pool to a (h, w) grid with h*w == T (guaranteed — see _token_grid),
        # so the flatten below yields exactly T tokens and the broadcast works.
        h, w = _token_grid(T, P_full.shape[-2], P_full.shape[-1])
        P = F.adaptive_avg_pool2d(P_full, output_size=(h, w))
        P = P.flatten(2).transpose(1, 2)  # (B, T, 1)
    # Per-sample max for normalization; clamp avoids 0/0 on a zero map.
    P_max = P.amax(dim=1, keepdim=True).clamp_min(eps)
    return density_tokens * (1.0 + P / P_max)