"""
Sparse-LiDAR prompt encoder.

Per-pixel sparse depth (B,1,H,W) + binary mask (B,1,H,W) are pooled to multiple
scales via masked average pooling. At each scale we keep both the pooled depth
and the *density* (fraction of observed pixels per cell) — paper §3.1 calls
this the per-token confidence signal that drives the prompt gate.

The output token grid is sized to match the DiT's stage-2 token grid (H/p, W/p),
which is where prompt fusion happens.
"""
from __future__ import annotations

from typing import Iterable

import torch
import torch.nn as nn
import torch.nn.functional as F


def masked_avg_pool(depth: torch.Tensor, mask: torch.Tensor, kernel: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Returns (pooled_depth, density). `mask` is bool/0-1. Both inputs (B,1,H,W)."""
    m = mask.float()
    # avg_pool2d * k^2 turns the windowed mean back into a windowed *sum*,
    # yielding the sum of observed depths and the count of observed pixels.
    summed = F.avg_pool2d(depth * m, kernel_size=kernel, stride=kernel, ceil_mode=False) * (kernel * kernel)
    count = F.avg_pool2d(m, kernel_size=kernel, stride=kernel, ceil_mode=False) * (kernel * kernel)
    pooled = summed / count.clamp_min(1.0)  # mean over observed pixels; 0 where a cell is empty
    density = count / (kernel * kernel)  # fraction of observed pixels per cell, in [0, 1]
    return pooled, density
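

# Illustrative sanity check (example shapes, not part of the module API): with a
# fully observed mask, masked_avg_pool reduces to plain average pooling and the
# density is exactly 1; with an all-zero mask, both outputs are 0.
#
#   d = torch.rand(2, 1, 64, 64)
#   pooled, density = masked_avg_pool(d, torch.ones_like(d), kernel=8)
#   assert torch.allclose(pooled, F.avg_pool2d(d, 8)) and bool((density == 1).all())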


def quantile_log_normalize(depth: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Per-sample 2/98 quantile log-depth normalization, matches PPD's GT scheme.

    Returns normalized depth clamped to [-0.5, 0.5]. Pixels with mask == 0 are
    set to 0 so they read as "no observation" downstream.
    """
    out = torch.zeros_like(depth)
    B = depth.shape[0]
    log_depth = torch.log1p(depth.clamp_min(0.0))  # log(depth + 1), safe at depth == 0
    for i in range(B):
        m = mask[i].bool()
        if m.sum() == 0:
            continue
        vals = log_depth[i][m]
        d_min = torch.quantile(vals, 0.02)
        d_max = torch.quantile(vals, 0.98)
        if (d_max - d_min) < 1e-6:
            d_max = d_min + 1e-6
        norm = (log_depth[i] - d_min) / (d_max - d_min) - 0.5
        norm = torch.clamp(norm, -0.5, 0.5)  # symmetric clamp, matching the documented range
        out[i] = norm * m.float()
    return out
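

# Worked example with hypothetical numbers: if one sample's observed log-depths
# have quantiles q02 = 0.7 and q98 = 3.1, then
#   norm(0.7) = (0.7 - 0.7) / (3.1 - 0.7) - 0.5 = -0.5
#   norm(3.1) = (3.1 - 0.7) / (3.1 - 0.7) - 0.5 = +0.5
# with everything outside [q02, q98] clamped into [-0.5, 0.5], and pixels with
# mask == 0 zeroed afterwards.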


class _SmallCNN(nn.Module):
    def __init__(self, in_ch: int, hidden: int, out_ch: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, hidden, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden, out_ch, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class SparsePromptEncoder(nn.Module):
    """Multi-scale sparse-prompt encoder.

    Args
    ----
    scales : pool kernels (in pixels). Paper §3.1 uses {4, 8, 16, 32} — kernel=4
        gives sub-token granularity (4×4 pixels per cell), kernel=32 gives
        global context. All scales are bilinearly resampled to the DiT
        stage-2 token grid before fusion.
    embed_dim : output token embedding dim (matches the DiT's hidden_size).
    out_grid_div : patch factor p2 of the stage-2 token grid; prompts are
        fused on an (H/p2, W/p2) grid, with p2 = 8 by default.
    """

    def __init__(
        self,
        scales: Iterable[int] = (4, 8, 16, 32),
        embed_dim: int = 1024,
        out_grid_div: int = 8,
        hidden: int = 128,
    ):
        super().__init__()
        self.scales = tuple(scales)
        self.embed_dim = embed_dim
        self.out_grid_div = out_grid_div
        # 2 channels per scale (depth + density) → CNN → embed_dim
        self.per_scale = nn.ModuleList(
            [_SmallCNN(2, hidden, embed_dim) for _ in self.scales]
        )
        # final mixer over concatenated multi-scale features
        self.fuse = nn.Linear(embed_dim * len(self.scales), embed_dim)
        # zero-init the final projection so the untrained model behaves like PPD
        nn.init.zeros_(self.fuse.weight)
        nn.init.zeros_(self.fuse.bias)

    def forward(
        self, sparse_depth: torch.Tensor, sparse_mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Returns (tokens, density_per_token).

        tokens: (B, T, embed_dim)
        density_per_token: (B, T, 1) — averaged density across scales, used by
            the prompt gate as a confidence weight.
        """
        # Normalize sparse depth once at the input resolution.
        norm_depth = quantile_log_normalize(sparse_depth, sparse_mask)
        B, _, H, W = sparse_depth.shape
        # Assumes H and W are divisible by out_grid_div (and by each pool scale);
        # otherwise edge pixels are silently dropped by the strided pooling.
        out_h, out_w = H // self.out_grid_div, W // self.out_grid_div
        feats: list[torch.Tensor] = []
        densities: list[torch.Tensor] = []
        for cnn, k in zip(self.per_scale, self.scales):
            pooled, density = masked_avg_pool(norm_depth, sparse_mask, kernel=k)
            x = torch.cat([pooled, density], dim=1)
            x = cnn(x)
            x = F.interpolate(x, size=(out_h, out_w), mode="bilinear", align_corners=False)
            d = F.interpolate(density, size=(out_h, out_w), mode="bilinear", align_corners=False)
            feats.append(x)
            densities.append(d)
        x = torch.cat(feats, dim=1)  # (B, embed_dim*len(scales), out_h, out_w)
        x = x.flatten(2).transpose(1, 2)  # (B, T, embed_dim*len(scales))
        x = self.fuse(x)  # (B, T, embed_dim)
        density = torch.stack(densities, dim=0).mean(dim=0)  # (B,1,out_h,out_w)
        density = density.flatten(2).transpose(1, 2)  # (B, T, 1)
        return x, density
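

# Minimal smoke test, a sketch rather than part of the module: the input
# resolution (256x320) and ~0.5% LiDAR coverage are arbitrary example values,
# not settings taken from the paper.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, H, W = 2, 256, 320
    depth = torch.rand(B, 1, H, W) * 80.0  # metric-ish depths in [0, 80)
    mask = (torch.rand(B, 1, H, W) < 0.005).float()  # ~0.5% observed pixels
    enc = SparsePromptEncoder(embed_dim=1024, out_grid_div=8)
    tokens, density = enc(depth * mask, mask)
    # Token count should match the stage-2 grid: (H/8) * (W/8).
    assert tokens.shape == (B, (H // 8) * (W // 8), 1024)
    assert density.shape == (B, (H // 8) * (W // 8), 1)
    # fuse is zero-initialized, so tokens are exactly zero before training.
    assert tokens.abs().max() == 0
    print("tokens:", tuple(tokens.shape), "density:", tuple(density.shape))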