"""
Uncertainty-guided prompt modulation — paper §3.5 (Eq. 7).

The Kalman variance feeds back into the prompt-encoder density signal:

    ρ̃(p) = ρ(p) · (1 + P(p) / max_q P(q))

High-uncertainty regions amplify the prompt gate (so sparse LiDAR draws more
attention there); low-uncertainty regions trust the temporal prior. No new
parameters are introduced; ρ̃ replaces ρ inside `PromptGate`.

Variance is taken at the prompt-token grid (downsampled to match) so the
per-token modulation lines up with `SparsePromptEncoder` outputs.
"""
from __future__ import annotations

import torch
import torch.nn.functional as F


def modulate_density(
    density_tokens: torch.Tensor, P_full: torch.Tensor | None, *, eps: float = 1e-6
) -> torch.Tensor:
    """Apply Eq. 7.

    density_tokens : (B, T, 1) — output of SparsePromptEncoder.
    P_full         : (B, 1, H, W) — current Kalman variance map at full res,
                     or None, in which case ρ is returned unmodified.

    Returns ρ̃ with the same shape as `density_tokens`.
    """
    if P_full is None:
        return density_tokens
    _, T, _ = density_tokens.shape
    # Reduce spatially to T tokens. We don't know the grid here — derive from T.
    # Caller can pass an already-downsampled P; we accept either.
    if P_full.shape[-2] * P_full.shape[-1] != T:
        # Pool to match token count using adaptive avg pool to a square-ish grid.
        # We assume T = h * w with h ≈ H_full/k, w ≈ W_full/k. Recover h,w from
        # the aspect ratio of P_full.
        h_full, w_full = P_full.shape[-2:]
        ratio = w_full / max(h_full, 1)
        h = int(round((T / max(ratio, 1e-6)) ** 0.5))
        w = T // max(h, 1)
        h = max(1, min(h, h_full))
        w = max(1, min(w, w_full))
        if h * w != T:
            # Fallback: largest h ≤ √T that divides T, so h * w == T exactly
            # and the pooled map matches the token count.
            h = max(1, int(T ** 0.5))
            while h > 1 and T % h != 0:
                h -= 1
            w = T // h
        P = F.adaptive_avg_pool2d(P_full, output_size=(h, w))
        P = P.flatten(2).transpose(1, 2)  # (B, T, 1)
    else:
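        # Variance already lives on the token grid (h * w == T), e.g. the
        # caller pre-pooled it; just flatten to (B, T, 1) in token order.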
        P = P_full.flatten(2).transpose(1, 2)

    # Per-sample max for normalization
    P_max = P.amax(dim=1, keepdim=True).clamp_min(eps)
    return density_tokens * (1.0 + P / P_max)
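

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original module: the batch size, the
# 256-token prompt grid, and the 64x64 variance map below are made-up numbers
# for illustration only. It shows how `modulate_density` is called; wiring
# ρ̃ into `PromptGate` / `SparsePromptEncoder` happens elsewhere in the repo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    B, T = 2, 256                          # 256 prompt tokens ≈ a 16 x 16 grid
    density_tokens = torch.rand(B, T, 1)   # stand-in for ρ from SparsePromptEncoder
    P_full = torch.rand(B, 1, 64, 64)      # stand-in for the Kalman variance map

    rho_tilde = modulate_density(density_tokens, P_full)
    assert rho_tilde.shape == density_tokens.shape  # (B, T, 1), Eq. 7 applied

    # With no variance map (e.g. the very first frame) ρ passes through unchanged.
    assert torch.equal(modulate_density(density_tokens, None), density_tokens)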