""" Uncertainty-guided prompt modulation — paper §3.5 (Eq. 7). The Kalman variance feeds back into the prompt-encoder density signal: ρ̃(p) = ρ(p) · (1 + P(p) / max_q P(q)) high-uncertainty regions amplify the prompt gate (so sparse LiDAR draws more attention there); low-uncertainty regions trust the temporal prior. No new parameters are introduced; ρ̃ replaces ρ inside `PromptGate`. Variance is taken at the prompt-token grid (downsampled to match) so the per-token modulation lines up with `SparsePromptEncoder` outputs. """ from __future__ import annotations import torch import torch.nn.functional as F def modulate_density( density_tokens: torch.Tensor, P_full: torch.Tensor, *, eps: float = 1e-6 ) -> torch.Tensor: """Apply Eq. 7. density_tokens : (B, T, 1) — output of SparsePromptEncoder. P_full : (B, 1, H, W) — current Kalman variance map at full res. Returns ρ̃ with the same shape as `density_tokens`. """ if P_full is None: return density_tokens B, T, _ = density_tokens.shape # Reduce spatially to T tokens. We don't know the grid here — derive from T. # Caller can pass an already-downsampled P; we accept either. if P_full.shape[-2] * P_full.shape[-1] != T: # Pool to match token count using adaptive avg pool to a square-ish grid. # We assume T = h * w with h ≈ H_full/k, w ≈ W_full/k. Recover h,w from # the aspect ratio of P_full. h_full, w_full = P_full.shape[-2:] ratio = w_full / max(h_full, 1) h = int(round((T / max(ratio, 1e-6)) ** 0.5)) w = T // max(h, 1) h = max(1, min(h, h_full)) w = max(1, min(w, w_full)) if h * w != T: # fallback: square grid that matches T h = int(T ** 0.5) w = T // max(h, 1) P = F.adaptive_avg_pool2d(P_full, output_size=(h, w)) P = P.flatten(2).transpose(1, 2) # (B, T, 1) else: P = P_full.flatten(2).transpose(1, 2) # Per-sample max for normalization P_max = P.amax(dim=1, keepdim=True).clamp_min(eps) return density_tokens * (1.0 + P / P_max)