| """ |
| Uncertainty-guided prompt modulation — paper §3.5 (Eq. 7). |
| |
| The Kalman variance feeds back into the prompt-encoder density signal: |
| |
| ρ̃(p) = ρ(p) · (1 + P(p) / max_q P(q)) |
| |
| high-uncertainty regions amplify the prompt gate (so sparse LiDAR draws more |
| attention there); low-uncertainty regions trust the temporal prior. No new |
| parameters are introduced; ρ̃ replaces ρ inside `PromptGate`. |
| |
| Variance is taken at the prompt-token grid (downsampled to match) so the |
| per-token modulation lines up with `SparsePromptEncoder` outputs. |
| """ |
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
|
|
def modulate_density(
    density_tokens: torch.Tensor, P_full: torch.Tensor | None, *, eps: float = 1e-6
) -> torch.Tensor:
    """Apply Eq. 7: ρ̃(p) = ρ(p) · (1 + P(p) / max_q P(q)).

    Parameters
    ----------
    density_tokens : (B, T, 1) tensor — output of SparsePromptEncoder.
    P_full : (B, 1, H, W) tensor or None — current Kalman variance map at
        full resolution. When None (no temporal prior yet) the density is
        returned unmodulated.
    eps : floor applied to the per-sample max variance; guards against
        division by zero on an all-zero variance map.

    Returns
    -------
    ρ̃ with the same shape as `density_tokens`.
    """
    if P_full is None:
        return density_tokens
    _, T, _ = density_tokens.shape
    if T == 0:
        # No prompt tokens — nothing to modulate (pooling to a 0-sized
        # grid would fail below).
        return density_tokens

    if P_full.shape[-2] * P_full.shape[-1] != T:
        # Variance map does not line up with the token grid: pool it down
        # to an (h, w) grid with h * w == T. First try a grid that keeps
        # the map's aspect ratio.
        h_full, w_full = P_full.shape[-2:]
        ratio = w_full / max(h_full, 1)
        h = int(round((T / max(ratio, 1e-6)) ** 0.5))
        w = T // max(h, 1)
        h = max(1, min(h, h_full))
        w = max(1, min(w, w_full))
        if h * w != T:
            # The aspect-preserving grid does not factor T exactly. Fall
            # back to the largest divisor of T that is <= sqrt(T). Unlike
            # the previous `int(T ** 0.5)` guess this ALWAYS satisfies
            # h * w == T (worst case h == 1, w == T), so the pooled map
            # matches the token count and the final broadcast cannot fail.
            h = max(d for d in range(1, int(T ** 0.5) + 1) if T % d == 0)
            w = T // h
        P = F.adaptive_avg_pool2d(P_full, output_size=(h, w))
        P = P.flatten(2).transpose(1, 2)  # (B, T, 1)
    else:
        P = P_full.flatten(2).transpose(1, 2)  # (B, T, 1)

    # Per-sample normalisation: each batch element is scaled by its own max
    # variance, floored at eps so a zero map degrades to unit gain (no-op).
    P_max = P.amax(dim=1, keepdim=True).clamp_min(eps)
    return density_tokens * (1.0 + P / P_max)
|
|