"""
LPD-DiT: PPD's DiT augmented with a sparse-LiDAR prompt path.

The base DiT already fuses VFM semantics at the midpoint (after block
depth // 2 - 1) by calling `proj_fusion(cat([x, normalize(semantics)], -1))`
and upsampling the token grid 2x to stage-2 resolution. Right after that
fusion we additionally inject the sparse-prompt tokens through `PromptGate`;
the remaining stage-2 blocks then attend over the gated tokens.

Only the prompt encoder and gate are new parameters; everything else is
identical to the pretrained DiT and can stay frozen.
"""
from __future__ import annotations
from typing import Optional
import torch
import torch.nn.functional as F
from ppd.models.dit import DiT
from ppd.lpd.prompt_encoder import SparsePromptEncoder
from ppd.lpd.prompt_gate import PromptGate
from ppd.lpd.uncertainty_modulation import modulate_density
class LPDDiT(DiT):
    """DiT + sparse-prompt fusion at the midpoint.

    `forward` accepts the original (x, semantics, timestep) plus optional
    sparse_depth + sparse_mask. When either sparse input is None the prompt
    path is skipped and behavior is identical to the parent DiT (so a
    checkpoint trained as PPD still runs).

    Token grids, with p = 2 * patch_size:
      * stage 1 (blocks [0, depth // 2)):     (H // p, W // p)
      * stage 2 (blocks [depth // 2, depth)): (2 * (H // p), 2 * (W // p))
    """

    def __init__(
        self,
        in_channels: int = 4,
        out_channels: int = 1,
        hidden_size: int = 1024,
        depth: int = 24,
        num_heads: int = 16,
        patch_size: int = 8,
        mlp_ratio: float = 4.0,
        prompt_scales: tuple[int, ...] = (4, 8, 16, 32),
        prompt_hidden: int = 128,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            hidden_size=hidden_size,
            depth=depth,
            num_heads=num_heads,
            patch_size=patch_size,
            mlp_ratio=mlp_ratio,
        )
        # The prompt encoder is configured with `out_grid_div=patch_size`,
        # intended to emit tokens on the (H / patch_size, W / patch_size)
        # grid, i.e. the stage-2 grid produced by the midpoint 2x upsample
        # (the stage-1 grid is H / (2 * patch_size)). Tokens use the DiT
        # hidden size so the gate can mix them directly into the stream.
        # If the encoder's grid does not match exactly, `forward` resamples.
        self.prompt_scales = tuple(prompt_scales)
        self.sparse_prompt_encoder = SparsePromptEncoder(
            scales=self.prompt_scales,
            embed_dim=hidden_size,
            out_grid_div=patch_size,
            hidden=prompt_hidden,
        )
        self.prompt_gate = PromptGate(
            embed_dim=hidden_size, timestep_dim=hidden_size
        )

    def forward(
        self,
        x: torch.Tensor,
        semantics: torch.Tensor,
        timestep: torch.Tensor,
        *,
        sparse_depth: Optional[torch.Tensor] = None,
        sparse_mask: Optional[torch.Tensor] = None,
        kalman_variance: Optional[torch.Tensor] = None,
        dropout: float = 0.1,
    ) -> torch.Tensor:
        """Run the DiT, gating sparse-prompt tokens in after the midpoint fusion.

        Args:
            x: Input tensor of shape (N, C, H, W).
            semantics: VFM semantic tokens; L2-normalized and concatenated
                with the stage-1 tokens along the last dim at the midpoint.
            timestep: Diffusion timestep; a 0-d tensor is promoted to (1,).
            sparse_depth: Sparse depth prompt. The prompt path runs only when
                both `sparse_depth` and `sparse_mask` are given.
            sparse_mask: Validity mask for `sparse_depth`.
            kalman_variance: Optional variance passed to `modulate_density`
                to rescale the density tokens; ignored without sparse inputs.
            dropout: NOTE(review): not read anywhere in this method; kept
                only for call-site compatibility.

        Returns:
            The result of `self.unpatchify(...)` at (H, W).
        """
        N, _, H, W = x.shape
        if timestep.ndim == 0:
            timestep = timestep[None]

        # Derive both token grids from patch_size, exactly as the midpoint
        # reshape does, so RoPE positions always match the token counts.
        p = self.patch_size * 2
        h1, w1 = H // p, W // p  # stage-1 token grid
        h2, w2 = h1 * 2, w1 * 2  # stage-2 token grid after the 2x unfold
        mid = self.depth // 2

        pos0 = pos1 = None
        if self.rope is not None:
            pos0 = self.position_getter(N, h1, w1, device=x.device)
            pos1 = self.position_getter(N, h2, w2, device=x.device)

        x = self.x_embedder(x)
        t = self.t_embedder(timestep)  # (N, D)

        # Encode the sparse prompt once, up front; it is consumed at the
        # stage-1 -> stage-2 transition below.
        prompt_tokens = density_tokens = None
        if sparse_depth is not None and sparse_mask is not None:
            prompt_tokens, density_tokens = self.sparse_prompt_encoder(
                sparse_depth, sparse_mask
            )
            if kalman_variance is not None:
                density_tokens = modulate_density(
                    density_tokens, kalman_variance
                )

        for i, block in enumerate(self.blocks):
            x = block(x, t, pos0 if i < mid else pos1)
            if i != mid - 1:
                continue
            # Stage-1 -> stage-2 transition: PPD's semantics fusion + unfold.
            x = self._fuse_semantics_and_upsample(x, semantics, N, h1, w1)
            # New: gate the prompt in before any stage-2 block runs.
            if prompt_tokens is not None:
                prompt_tokens = self._resample_tokens_to_grid(
                    prompt_tokens, h2, w2
                )
                if density_tokens is not None:
                    density_tokens = self._resample_tokens_to_grid(
                        density_tokens, h2, w2
                    )
                x = self.prompt_gate(x, prompt_tokens, density_tokens, t)

        x = self.final_layer(x, t)
        return self.unpatchify(x, height=H, width=W)

    def _fuse_semantics_and_upsample(
        self,
        x: torch.Tensor,
        semantics: torch.Tensor,
        batch: int,
        grid_h: int,
        grid_w: int,
    ) -> torch.Tensor:
        """Reproduce PPD's midpoint fusion and 2x token-grid unfold.

        The stage-1 tokens (laid out on a `grid_h` x `grid_w` grid) are
        concatenated with L2-normalized `semantics` and projected by
        `proj_fusion`. Each output token's channels are split into a 2x2
        patch of D = channels // 4 features. The result is a
        (batch, 4 * grid_h * grid_w, D) token sequence on the doubled grid.
        """
        semantics_norm = F.normalize(semantics, dim=-1)
        fused = self.proj_fusion(torch.cat([x, semantics_norm], dim=-1))
        d = fused.shape[-1] // 4
        fused = fused.reshape(batch, grid_h, grid_w, 2, 2, d)
        fused = torch.einsum("nhwpqc->nchpwq", fused)
        fused = fused.reshape(batch, d, grid_h * 2, grid_w * 2)
        return fused.flatten(2).transpose(1, 2)

    @staticmethod
    def _infer_token_grid(
        num_tokens: int, ref_h: int, ref_w: int
    ) -> tuple[int, int]:
        """Factor `num_tokens` into an (h, w) grid closest in aspect to ref_h:ref_w.

        A plain square-root guess only works for square grids; for e.g.
        60 x 80 = 4800 tokens it yields 69 x 69 and the reshape fails.
        """
        if num_tokens < 1:
            raise ValueError("cannot resample an empty token sequence")
        ref_h, ref_w = max(ref_h, 1), max(ref_w, 1)
        best_hw = (1, num_tokens)
        best_err = None
        d = 1
        while d * d <= num_tokens:
            if num_tokens % d == 0:
                for gh, gw in ((d, num_tokens // d), (num_tokens // d, d)):
                    # Symmetric aspect mismatch: 1.0 is a perfect match.
                    a, b = gh * ref_w, gw * ref_h
                    err = max(a, b) / min(a, b)
                    if best_err is None or err < best_err:
                        best_hw, best_err = (gh, gw), err
            d += 1
        return best_hw

    def _resample_tokens_to_grid(
        self, tokens: torch.Tensor, grid_h: int, grid_w: int
    ) -> torch.Tensor:
        """Bilinearly resample (N, L, C) tokens onto a `grid_h` x `grid_w` grid.

        Returns `tokens` unchanged when L already equals grid_h * grid_w.
        Otherwise the source grid layout is inferred from L and the target
        aspect ratio.
        """
        batch, num_tokens, channels = tokens.shape
        if num_tokens == grid_h * grid_w:
            return tokens
        src_h, src_w = self._infer_token_grid(num_tokens, grid_h, grid_w)
        grid = tokens.transpose(1, 2).reshape(batch, channels, src_h, src_w)
        grid = F.interpolate(
            grid,
            size=(grid_h, grid_w),
            mode="bilinear",
            align_corners=False,
        )
        return grid.flatten(2).transpose(1, 2)

    # ------------------------------------------------------------------
    # Helpers for fine-tuning only the prompt branch on top of a vanilla
    # PPD checkpoint
    # ------------------------------------------------------------------
    def freeze_backbone(self) -> None:
        """Freeze every parameter that came from the parent DiT.

        Only the prompt encoder + gate stay trainable, matching paper §3.6:
        all extensions are inference-time mechanisms or lightweight prompt
        modules training fewer than 1% of total parameters.
        """
        # Freeze everything first, then re-enable the prompt branches.
        for param in self.parameters():
            param.requires_grad = False
        for module in (self.sparse_prompt_encoder, self.prompt_gate):
            for param in module.parameters():
                param.requires_grad = True

    def num_trainable_params(self) -> int:
        """Return the number of elements in parameters with requires_grad=True."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def num_total_params(self) -> int:
        """Return the total number of parameter elements in the model."""
        return sum(p.numel() for p in self.parameters())