"""
Sparse-LiDAR prompt encoder.

Per-pixel sparse depth (B,1,H,W) + binary mask (B,1,H,W) are pooled to multiple
scales via masked average pooling. At each scale we keep both the pooled depth
and the *density* (fraction of observed pixels per cell) — paper §3.1 calls
this the per-token confidence signal that drives the prompt gate.

The output token grid is sized to match the DiT's stage-2 token grid (H/p, W/p),
which is where prompt fusion happens.
"""
from __future__ import annotations

from typing import Iterable

import torch
import torch.nn as nn
import torch.nn.functional as F


def masked_avg_pool(depth: torch.Tensor, mask: torch.Tensor, kernel: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Returns (pooled_depth, density). `mask` is bool/0-1. Both inputs are (B,1,H,W)."""
    m = mask.float()
    # avg_pool2d scaled by kernel**2 is sum pooling: per-cell sums of observed
    # depth values and of observation counts.
    summed = F.avg_pool2d(depth * m, kernel_size=kernel, stride=kernel, ceil_mode=False) * (kernel * kernel)
    count = F.avg_pool2d(m, kernel_size=kernel, stride=kernel, ceil_mode=False) * (kernel * kernel)
    pooled = summed / count.clamp_min(1.0)  # cells with no observations pool to 0
    density = count / (kernel * kernel)  # fraction of observed pixels per cell
    return pooled, density
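

# Illustrative contract for masked_avg_pool (a sketch with assumed sizes, not
# executed on import): for (B, 1, 64, 64) inputs and kernel=8, both outputs are
# (B, 1, 8, 8). Cells with no observed pixels get depth 0 and density 0, so the
# density channel doubles as a per-cell confidence in [0, 1].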


def quantile_log_normalize(depth: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Per-sample 2/98-quantile log-depth normalization, matching PPD's GT scheme.

    Maps the inlier log-depth range to roughly [-0.5, 0.5]; outliers beyond the
    quantiles are clamped to [-0.5, 1.0]. Pixels with mask == 0 are set to 0 so
    they look like "no observation" downstream.
    """
    out = torch.zeros_like(depth)
    B = depth.shape[0]
    log_depth = torch.log1p(depth.clamp_min(0.0))  # log(1 + d), safe for d >= 0
    for i in range(B):
        m = mask[i].bool()
        if not m.any():
            continue  # no observations: leave this sample all-zero
        vals = log_depth[i][m]
        d_min = torch.quantile(vals, 0.02)
        d_max = torch.quantile(vals, 0.98)
        if (d_max - d_min) < 1e-6:
            d_max = d_min + 1e-6  # degenerate range: avoid division by zero
        norm = (log_depth[i] - d_min) / (d_max - d_min) - 0.5
        norm = torch.clamp(norm, -0.5, 1.0)
        out[i] = norm * m.float()
    return out
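

# Worked example of the normalization (a sketch; the depth range is assumed):
# for a sample whose observed depths span roughly 2-60 m, the 2%/98% quantiles
# of log-depth map onto -0.5 and 0.5; the few observed pixels outside that
# range are clamped into [-0.5, 1.0], and unobserved pixels stay exactly 0.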


class _SmallCNN(nn.Module):
    """Two-layer 3×3 conv block applied independently at each pooling scale."""

    def __init__(self, in_ch: int, hidden: int, out_ch: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, hidden, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden, out_ch, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class SparsePromptEncoder(nn.Module):
    """Multi-scale sparse-prompt encoder.

    Args
    ----
    scales : pool kernels (in pixels). Paper §3.1 uses {4, 8, 16, 32} — kernel=4
        gives sub-token granularity (4×4 pixels per cell), kernel=32 gives
        global context. All scales are bilinearly resampled to the DiT
        stage-2 token grid before fusion.
    embed_dim : output token embedding dim (matches the DiT's hidden_size).
    out_grid_div : the model fuses prompts at the stage-2 grid, which is
        (H/p2, W/p2) with p2 = 8 by default.
    hidden : hidden width of the per-scale CNNs.
    """

    def __init__(
        self,
        scales: Iterable[int] = (4, 8, 16, 32),
        embed_dim: int = 1024,
        out_grid_div: int = 8,
        hidden: int = 128,
    ):
        super().__init__()
        self.scales = tuple(scales)
        self.embed_dim = embed_dim
        self.out_grid_div = out_grid_div
        # 2 channels per scale (depth + density) → CNN → embed_dim
        self.per_scale = nn.ModuleList(
            [_SmallCNN(2, hidden, embed_dim) for _ in self.scales]
        )
        # Final mixer over concatenated multi-scale features.
        self.fuse = nn.Linear(embed_dim * len(self.scales), embed_dim)
        # Zero-init the final projection so the untrained encoder emits all-zero
        # prompt tokens and the model behaves like plain PPD.
        nn.init.zeros_(self.fuse.weight)
        nn.init.zeros_(self.fuse.bias)

    def forward(
        self, sparse_depth: torch.Tensor, sparse_mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Returns (tokens, density_per_token).

        tokens: (B, T, embed_dim)
        density_per_token: (B, T, 1) — density averaged across scales, used by
            the prompt gate as a confidence weight.
        """
        # Normalize sparse depth once at the input resolution.
        norm_depth = quantile_log_normalize(sparse_depth, sparse_mask)
        B, _, H, W = sparse_depth.shape
        out_h, out_w = H // self.out_grid_div, W // self.out_grid_div
        feats: list[torch.Tensor] = []
        densities: list[torch.Tensor] = []
        for cnn, k in zip(self.per_scale, self.scales):
            pooled, density = masked_avg_pool(norm_depth, sparse_mask, kernel=k)
            x = torch.cat([pooled, density], dim=1)  # (B, 2, H/k, W/k)
            x = cnn(x)
            x = F.interpolate(x, size=(out_h, out_w), mode="bilinear", align_corners=False)
            d = F.interpolate(density, size=(out_h, out_w), mode="bilinear", align_corners=False)
            feats.append(x)
            densities.append(d)
        x = torch.cat(feats, dim=1)  # (B, embed_dim*len(scales), out_h, out_w)
        x = x.flatten(2).transpose(1, 2)  # (B, T, embed_dim*len(scales))
        x = self.fuse(x)  # (B, T, embed_dim)
        density = torch.stack(densities, dim=0).mean(dim=0)  # (B, 1, out_h, out_w)
        density = density.flatten(2).transpose(1, 2)  # (B, T, 1)
        return x, density
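

if __name__ == "__main__":
    # Minimal smoke test / usage sketch. All sizes here are illustrative
    # assumptions, not the paper's training configuration.
    torch.manual_seed(0)
    B, H, W = 2, 64, 64
    depth = torch.rand(B, 1, H, W) * 80.0  # meters (illustrative scale)
    mask = (torch.rand(B, 1, H, W) > 0.95).float()  # ~5% coverage, LiDAR-like

    pooled, density = masked_avg_pool(depth, mask, kernel=8)
    print("pooled:", tuple(pooled.shape), "density:", tuple(density.shape))

    enc = SparsePromptEncoder(scales=(4, 8, 16, 32), embed_dim=64, out_grid_div=8, hidden=32)
    tokens, conf = enc(depth, mask)
    print("tokens:", tuple(tokens.shape), "confidence:", tuple(conf.shape))

    # Because self.fuse is zero-initialized, an untrained encoder emits exactly
    # zero prompt tokens, leaving the base model's behaviour unchanged.
    assert tokens.abs().max().item() == 0.0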