""" Sparse-LiDAR prompt encoder. Per-pixel sparse depth (B,1,H,W) + binary mask (B,1,H,W) are pooled to multiple scales via masked average pooling. At each scale we keep both the pooled depth and the *density* (fraction of observed pixels per cell) — paper §3.1 calls this the per-token confidence signal that drives the prompt gate. The output token grid is sized to match the DiT's stage-2 token grid (H/p, W/p), which is where prompt fusion happens. """ from __future__ import annotations import math from typing import Iterable import torch import torch.nn as nn import torch.nn.functional as F def masked_avg_pool(depth: torch.Tensor, mask: torch.Tensor, kernel: int) -> tuple[torch.Tensor, torch.Tensor]: """Returns (pooled_depth, density). `mask` is bool/0-1. Both inputs (B,1,H,W).""" m = mask.float() summed = F.avg_pool2d(depth * m, kernel_size=kernel, stride=kernel, ceil_mode=False) * (kernel * kernel) count = F.avg_pool2d(m, kernel_size=kernel, stride=kernel, ceil_mode=False) * (kernel * kernel) pooled = summed / count.clamp_min(1.0) density = count / (kernel * kernel) return pooled, density def quantile_log_normalize(depth: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: """Per-sample 2/98 quantile log-depth normalization, matches PPD's GT scheme. Returns normalized depth in roughly [-0.5, 0.5]. Pixels with mask == 0 are set to 0 so they look like "no observation" downstream. """ out = torch.zeros_like(depth) B = depth.shape[0] log_depth = torch.log(depth.clamp_min(0.0) + 1.0) for i in range(B): m = mask[i].bool() if m.sum() == 0: continue vals = log_depth[i][m] d_min = torch.quantile(vals, 0.02) d_max = torch.quantile(vals, 0.98) if (d_max - d_min) < 1e-6: d_max = d_min + 1e-6 norm = (log_depth[i] - d_min) / (d_max - d_min) - 0.5 norm = torch.clamp(norm, -0.5, 1.0) out[i] = norm * m.float() return out class _SmallCNN(nn.Module): def __init__(self, in_ch: int, hidden: int, out_ch: int): super().__init__() self.net = nn.Sequential( nn.Conv2d(in_ch, hidden, kernel_size=3, padding=1), nn.GELU(), nn.Conv2d(hidden, out_ch, kernel_size=3, padding=1), ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.net(x) class SparsePromptEncoder(nn.Module): """Multi-scale sparse-prompt encoder. Args ---- scales : pool kernels (in pixels). Paper §3.1 uses {4, 8, 16, 32} — kernel=4 gives sub-token granularity (4×4 pixels per cell), kernel=32 gives global context. All scales are bilinearly resampled to the DiT stage-2 token grid before fusion. embed_dim : output token embedding dim (matches the DiT's hidden_size). out_grid_div : the model fuses prompts at the stage-2 grid which is H/p2, W/p2 with p2 = 8 by default. """ def __init__( self, scales: Iterable[int] = (4, 8, 16, 32), embed_dim: int = 1024, out_grid_div: int = 8, hidden: int = 128, ): super().__init__() self.scales = tuple(scales) self.embed_dim = embed_dim self.out_grid_div = out_grid_div # 2 channels per scale (depth + density) → CNN → embed_dim self.per_scale = nn.ModuleList( [_SmallCNN(2, hidden, embed_dim) for _ in self.scales] ) # final mixer over concatenated multi-scale features self.fuse = nn.Linear(embed_dim * len(self.scales), embed_dim) # zero-init the final projection so untrained model behaves like PPD nn.init.zeros_(self.fuse.weight) nn.init.zeros_(self.fuse.bias) def forward( self, sparse_depth: torch.Tensor, sparse_mask: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: """Returns (tokens, density_per_token). tokens: (B, T, embed_dim) density_per_token: (B, T, 1) — averaged density across scales, used by the prompt gate as a confidence weight. """ # Normalize sparse depth once at the input resolution. norm_depth = quantile_log_normalize(sparse_depth, sparse_mask) B, _, H, W = sparse_depth.shape out_h, out_w = H // self.out_grid_div, W // self.out_grid_div feats: list[torch.Tensor] = [] densities: list[torch.Tensor] = [] for cnn, k in zip(self.per_scale, self.scales): pooled, density = masked_avg_pool(norm_depth, sparse_mask, kernel=k) x = torch.cat([pooled, density], dim=1) x = cnn(x) x = F.interpolate(x, size=(out_h, out_w), mode="bilinear", align_corners=False) d = F.interpolate(density, size=(out_h, out_w), mode="bilinear", align_corners=False) feats.append(x) densities.append(d) x = torch.cat(feats, dim=1) # (B, embed_dim*len(scales), out_h, out_w) x = x.flatten(2).transpose(1, 2) # (B, T, embed_dim*len(scales)) x = self.fuse(x) # (B, T, embed_dim) density = torch.stack(densities, dim=0).mean(dim=0) # (B,1,out_h,out_w) density = density.flatten(2).transpose(1, 2) # (B, T, 1) return x, density