| """ |
| Sparse-LiDAR prompt encoder. |
| |
| Per-pixel sparse depth (B,1,H,W) + binary mask (B,1,H,W) are pooled to multiple |
| scales via masked average pooling. At each scale we keep both the pooled depth |
| and the *density* (fraction of observed pixels per cell) — paper §3.1 calls |
| this the per-token confidence signal that drives the prompt gate. |
| |
| The output token grid is sized to match the DiT's stage-2 token grid (H/p, W/p), |
| which is where prompt fusion happens. |
| """ |
from __future__ import annotations

from typing import Iterable

import torch
import torch.nn as nn
import torch.nn.functional as F


def masked_avg_pool(depth: torch.Tensor, mask: torch.Tensor, kernel: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Returns (pooled_depth, density). `mask` is bool/0-1. Both inputs (B,1,H,W).

    With ceil_mode=False, trailing rows/cols not divisible by `kernel` are dropped,
    so H and W should be multiples of the largest pooling kernel.
    """
    m = mask.float()
    # avg_pool2d multiplied by kernel**2 recovers the per-cell sum of observed
    # depths and the per-cell count of observed pixels.
    summed = F.avg_pool2d(depth * m, kernel_size=kernel, stride=kernel, ceil_mode=False) * (kernel * kernel)
    count = F.avg_pool2d(m, kernel_size=kernel, stride=kernel, ceil_mode=False) * (kernel * kernel)
    # Average over observed pixels only; cells with no observations stay 0.
    pooled = summed / count.clamp_min(1.0)
    density = count / (kernel * kernel)
    return pooled, density

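
# Illustrative check (a sketch, not part of the model): one observed pixel of
# depth 2.0 in a 1×1×4×4 map pooled with kernel=2 gives a 1×1×2×2 grid where the
# cell containing the observation keeps value 2.0 with density 1/4 = 0.25, and
# the three empty cells are 0.0 with density 0.0:
#
#   depth = torch.zeros(1, 1, 4, 4)
#   depth[0, 0, 0, 0] = 2.0
#   mask = depth > 0
#   pooled, density = masked_avg_pool(depth, mask, kernel=2)

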
def quantile_log_normalize(depth: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Per-sample 2/98 quantile log-depth normalization, matching PPD's GT scheme.

    Returns normalized depth in roughly [-0.5, 0.5]. Pixels with mask == 0 are
    set to 0 so they look like "no observation" downstream.
    """
    out = torch.zeros_like(depth)
    B = depth.shape[0]
    log_depth = torch.log(depth.clamp_min(0.0) + 1.0)
    for i in range(B):
        m = mask[i].bool()
        if m.sum() == 0:
            # No observed pixels in this sample: leave it all-zero.
            continue
        vals = log_depth[i][m]
        d_min = torch.quantile(vals, 0.02)
        d_max = torch.quantile(vals, 0.98)
        if (d_max - d_min) < 1e-6:
            d_max = d_min + 1e-6
        norm = (log_depth[i] - d_min) / (d_max - d_min) - 0.5
        # Clamp so outliers beyond the 2/98 quantiles stay within the documented
        # [-0.5, 0.5] range.
        norm = torch.clamp(norm, -0.5, 0.5)
        out[i] = norm * m.float()
    return out

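
# Worked example (illustrative numbers only): if a sample's 2%/98% log-depth
# quantiles are 1.0 and 3.0, a pixel with log-depth 2.0 maps to
# (2.0 - 1.0) / (3.0 - 1.0) - 0.5 = 0.0, the quantile endpoints map to -0.5 and
# +0.5, and anything outside them is clamped into that range; unobserved pixels
# are zeroed by the final mask multiply.

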
class _SmallCNN(nn.Module):
    def __init__(self, in_ch: int, hidden: int, out_ch: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, hidden, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden, out_ch, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class SparsePromptEncoder(nn.Module):
    """Multi-scale sparse-prompt encoder.

    Args
    ----
    scales : pool kernels (in pixels). Paper §3.1 uses {4, 8, 16, 32} — kernel=4
        gives sub-token granularity (4×4 pixels per cell), kernel=32 gives
        global context. All scales are bilinearly resampled to the DiT
        stage-2 token grid before fusion.
    embed_dim : output token embedding dim (matches the DiT's hidden_size).
    out_grid_div : the model fuses prompts at the stage-2 grid, which is
        (H/p2, W/p2) with p2 = 8 by default.
    hidden : channel width of the per-scale CNNs.
    """

    def __init__(
        self,
        scales: Iterable[int] = (4, 8, 16, 32),
        embed_dim: int = 1024,
        out_grid_div: int = 8,
        hidden: int = 128,
    ):
        super().__init__()
        self.scales = tuple(scales)
        self.embed_dim = embed_dim
        self.out_grid_div = out_grid_div

        # One small CNN per scale; each sees (pooled_depth, density) as 2 input channels.
        self.per_scale = nn.ModuleList(
            [_SmallCNN(2, hidden, embed_dim) for _ in self.scales]
        )

        self.fuse = nn.Linear(embed_dim * len(self.scales), embed_dim)
        # Zero-init the fusion projection so prompt tokens are exactly zero at
        # initialization and only start influencing the DiT as training updates them.
        nn.init.zeros_(self.fuse.weight)
        nn.init.zeros_(self.fuse.bias)

    def forward(
        self, sparse_depth: torch.Tensor, sparse_mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Returns (tokens, density_per_token).

        tokens: (B, T, embed_dim)
        density_per_token: (B, T, 1) — averaged density across scales, used by
            the prompt gate as a confidence weight.
        """
        norm_depth = quantile_log_normalize(sparse_depth, sparse_mask)
        B, _, H, W = sparse_depth.shape
        out_h, out_w = H // self.out_grid_div, W // self.out_grid_div

        feats: list[torch.Tensor] = []
        densities: list[torch.Tensor] = []
        for cnn, k in zip(self.per_scale, self.scales):
            pooled, density = masked_avg_pool(norm_depth, sparse_mask, kernel=k)
            # Per-scale features from (pooled_depth, density), resampled to the
            # stage-2 token grid before fusion.
            x = torch.cat([pooled, density], dim=1)
            x = cnn(x)
            x = F.interpolate(x, size=(out_h, out_w), mode="bilinear", align_corners=False)
            d = F.interpolate(density, size=(out_h, out_w), mode="bilinear", align_corners=False)
            feats.append(x)
            densities.append(d)

        # Concatenate scales along channels, flatten to tokens, and project back
        # to embed_dim: (B, len(scales)*embed_dim, h, w) -> (B, T, embed_dim).
        x = torch.cat(feats, dim=1)
        x = x.flatten(2).transpose(1, 2)
        x = self.fuse(x)

        # Per-token confidence: density averaged over scales -> (B, T, 1).
        density = torch.stack(densities, dim=0).mean(dim=0)
        density = density.flatten(2).transpose(1, 2)
        return x, density

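
# Minimal smoke test: a usage sketch with assumed shapes and random stand-in
# LiDAR, not part of any training pipeline. Running this file directly prints
# the prompt-token and density shapes and checks the zero-init behaviour.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, H, W = 2, 256, 320                    # divisible by the largest scale (32)
    depth = torch.rand(B, 1, H, W) * 80.0    # random dense depths as a stand-in
    mask = torch.rand(B, 1, H, W) < 0.05     # ~5% of pixels "observed"
    depth = depth * mask                     # zero out unobserved pixels

    enc = SparsePromptEncoder(embed_dim=1024, out_grid_div=8)
    tokens, density = enc(depth, mask)

    T = (H // 8) * (W // 8)
    assert tokens.shape == (B, T, 1024)
    assert density.shape == (B, T, 1)
    # fuse is zero-initialized, so tokens are exactly zero before any training.
    assert tokens.abs().max().item() == 0.0
    print("tokens", tuple(tokens.shape), "density", tuple(density.shape))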