File size: 6,852 Bytes
436b829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""
LPD-DiT: PPD's DiT augmented with a sparse-LiDAR prompt path.

The base DiT already fuses VFM semantics at the midpoint (right after block index depth//2 - 1) by
calling `proj_fusion(cat([x, semantics], -1))` and upsampling the token grid
to stage-2 resolution. Right after that fusion we additionally inject the
sparse-prompt tokens through `PromptGate`. The remaining stage-2 blocks then
attend over the gated tokens.

Only the prompt encoder + gate are new parameters; everything else is
identical to the pretrained DiT and can stay frozen.
"""
from __future__ import annotations

import math
from typing import Optional

import torch
import torch.nn.functional as F

from ppd.models.dit import DiT
from ppd.lpd.prompt_encoder import SparsePromptEncoder
from ppd.lpd.prompt_gate import PromptGate
from ppd.lpd.uncertainty_modulation import modulate_density


class LPDDiT(DiT):
    """DiT + sparse-prompt fusion at the midpoint.

    `forward` accepts the original (x, semantics, timestep) plus optional
    sparse_depth + sparse_mask. When either sparse input is None the prompt
    branch is skipped and behavior is identical to the parent DiT (so a
    checkpoint trained as PPD still runs).
    """

    def __init__(
        self,
        in_channels: int = 4,
        out_channels: int = 1,
        hidden_size: int = 1024,
        depth: int = 24,
        num_heads: int = 16,
        patch_size: int = 8,
        mlp_ratio: float = 4.0,
        prompt_scales: tuple[int, ...] = (4, 8, 16, 32),
        prompt_hidden: int = 128,
    ):
        """Build the parent DiT plus the sparse-prompt encoder and gate.

        Args:
            in_channels, out_channels, hidden_size, depth, num_heads,
            patch_size, mlp_ratio: forwarded unchanged to ``DiT.__init__``.
            prompt_scales: ``scales`` argument of ``SparsePromptEncoder``.
            prompt_hidden: ``hidden`` width of ``SparsePromptEncoder``.
        """
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            hidden_size=hidden_size,
            depth=depth,
            num_heads=num_heads,
            patch_size=patch_size,
            mlp_ratio=mlp_ratio,
        )
        # Token grids (p = patch_size): stage 1 runs on an (H/(2p), W/(2p))
        # grid, which `forward` pixel-shuffles 2x into the (H/p, W/p) stage-2
        # grid. The prompt encoder is built with out_grid_div=patch_size, so its
        # tokens presumably land on that same (H/p, W/p) grid; `forward`
        # resamples them when the token counts still disagree.
        self.prompt_scales = tuple(prompt_scales)
        self.sparse_prompt_encoder = SparsePromptEncoder(
            scales=self.prompt_scales,
            embed_dim=hidden_size,
            out_grid_div=patch_size,
            hidden=prompt_hidden,
        )
        self.prompt_gate = PromptGate(
            embed_dim=hidden_size, timestep_dim=hidden_size
        )

    def forward(
        self,
        x: torch.Tensor,
        semantics: torch.Tensor,
        timestep: torch.Tensor,
        *,
        sparse_depth: Optional[torch.Tensor] = None,
        sparse_mask: Optional[torch.Tensor] = None,
        kalman_variance: Optional[torch.Tensor] = None,
        dropout: float = 0.1,
    ) -> torch.Tensor:
        """Run the DiT, injecting sparse-prompt tokens at the midpoint.

        Args:
            x: 4-D input of shape (N, C, H, W).
            semantics: VFM features concatenated with the stage-1 tokens along
                the last dim before ``proj_fusion`` (presumably (N, L1, Ds) —
                confirm against the parent DiT).
            timestep: diffusion timestep; a 0-d tensor is promoted to shape (1,).
            sparse_depth, sparse_mask: sparse prompt inputs. The prompt branch
                runs only when BOTH are given; if either is None it is skipped.
            kalman_variance: optional; when given (and the prompt branch runs),
                density tokens are passed through ``modulate_density``.
            dropout: currently unused; kept for call-site compatibility.

        Returns:
            ``self.unpatchify(..., height=H, width=W)`` of the final-layer output.
        """
        N, _, H, W = x.shape
        if timestep.ndim == 0:
            timestep = timestep[None]

        # Stage-1 tokens sit on an (h1, w1) grid of (2 * patch_size) patches;
        # stage 2 is that grid pixel-shuffled 2x. Derive both grids from
        # patch_size so RoPE positions always match the actual token counts.
        p = self.patch_size * 2
        h1, w1 = H // p, W // p
        h2, w2 = h1 * 2, w1 * 2

        pos0 = pos1 = None
        if self.rope is not None:
            pos0 = self.position_getter(N, h1, w1, device=x.device)
            pos1 = self.position_getter(N, h2, w2, device=x.device)

        x = self.x_embedder(x)
        t = self.t_embedder(timestep)

        # Pre-compute prompt tokens once, only if both sparse inputs are given.
        prompt_tokens = density_tokens = None
        if sparse_depth is not None and sparse_mask is not None:
            prompt_tokens, density_tokens = self.sparse_prompt_encoder(
                sparse_depth, sparse_mask
            )
            if kalman_variance is not None:
                density_tokens = modulate_density(density_tokens, kalman_variance)

        mid = self.depth // 2
        for i, block in enumerate(self.blocks):
            x = block(x, t, pos0 if i < mid else pos1)

            if i == mid - 1:
                # Stage-1 -> stage-2 transition: PPD's semantics fusion + reshape.
                x = self._fuse_and_upsample(x, semantics, N, h1, w1)
                # New: gate in prompt tokens before the stage-2 blocks run.
                if prompt_tokens is not None:
                    if prompt_tokens.shape[1] != x.shape[1]:
                        src_h, src_w = self._infer_token_grid(
                            prompt_tokens.shape[1], h2, w2
                        )
                        prompt_tokens = self._resample_tokens(
                            prompt_tokens, src_h, src_w, h2, w2
                        )
                        density_tokens = self._resample_tokens(
                            density_tokens, src_h, src_w, h2, w2
                        )
                    x = self.prompt_gate(x, prompt_tokens, density_tokens, t)

        x = self.final_layer(x, t)
        x = self.unpatchify(x, height=H, width=W)
        return x

    # ------------------------------------------------------------------
    # Private helpers for `forward`
    # ------------------------------------------------------------------
    def _fuse_and_upsample(
        self, x: torch.Tensor, semantics: torch.Tensor, n: int, h1: int, w1: int
    ) -> torch.Tensor:
        """Fuse L2-normalized semantics into x, then pixel-shuffle (h1, w1) -> (2*h1, 2*w1).

        ``proj_fusion`` output width is split into 2x2 sub-tokens of width
        ``d = width // 4``; the result is token-major (n, 4*h1*w1, d).
        """
        semantics_norm = F.normalize(semantics, dim=-1)
        x = self.proj_fusion(torch.cat([x, semantics_norm], dim=-1))
        d = x.shape[-1] // 4
        x = x.reshape(n, h1, w1, 2, 2, d)
        x = torch.einsum("nhwpqc->nchpwq", x)
        x = x.reshape(n, d, h1 * 2, w1 * 2)
        return x.flatten(2).transpose(1, 2)

    @staticmethod
    def _infer_token_grid(num_tokens: int, ref_h: int, ref_w: int) -> tuple[int, int]:
        """Factor ``num_tokens`` into an exact (h, w) grid closest in aspect to ref_h:ref_w.

        Unlike a floor(sqrt) guess, this always satisfies ``h * w == num_tokens``
        and handles non-square grids. For a square reference it prefers h <= w,
        i.e. the same choice as floor(sqrt) whenever that choice was exact.

        Raises:
            ValueError: if ``num_tokens`` < 1.
        """
        if num_tokens < 1:
            raise ValueError(f"cannot infer a token grid for {num_tokens} tokens")
        target = math.log(ref_h / ref_w) if ref_h > 0 and ref_w > 0 else 0.0
        best, best_err = (1, num_tokens), math.inf
        for h in range(1, math.isqrt(num_tokens) + 1):
            if num_tokens % h:
                continue
            w = num_tokens // h
            for cand_h, cand_w in ((h, w), (w, h)):
                err = abs(math.log(cand_h / cand_w) - target)
                if err < best_err:
                    best, best_err = (cand_h, cand_w), err
        return best

    @staticmethod
    def _resample_tokens(
        tokens: torch.Tensor, src_h: int, src_w: int, dst_h: int, dst_w: int
    ) -> torch.Tensor:
        """Bilinearly resample (N, src_h*src_w, C) tokens to (N, dst_h*dst_w, C).

        The channel count is taken from ``tokens`` itself, so this works for
        both the prompt tokens and the density tokens.
        """
        n, _, c = tokens.shape
        grid = tokens.transpose(1, 2).reshape(n, c, src_h, src_w)
        grid = F.interpolate(
            grid, size=(dst_h, dst_w), mode="bilinear", align_corners=False
        )
        return grid.flatten(2).transpose(1, 2)

    # ------------------------------------------------------------------
    # Helpers for partial-loading from a vanilla PPD checkpoint
    # ------------------------------------------------------------------
    def freeze_backbone(self) -> None:
        """Freeze every parameter that came from the parent DiT.

        Only the prompt encoder + gate stay trainable, matching paper §3.6:
        all extensions are inference-time mechanisms or lightweight prompt
        modules training fewer than 1% of total parameters.
        """
        # Freeze everything first, then re-enable the prompt branches.
        for p in self.parameters():
            p.requires_grad = False
        for p in self.sparse_prompt_encoder.parameters():
            p.requires_grad = True
        for p in self.prompt_gate.parameters():
            p.requires_grad = True

    def num_trainable_params(self) -> int:
        """Return the number of scalar parameters with ``requires_grad=True``."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def num_total_params(self) -> int:
        """Return the total number of scalar parameters, trainable or not."""
        return sum(p.numel() for p in self.parameters())