# LiDAR-Perfect-Depth / code / ppd / lpd / lpd_dit.py
# (repository page header: chenming-wu's picture - code - 436b829 verified)
"""
LPD-DiT: PPD's DiT augmented with a sparse-LiDAR prompt path.
The base DiT already fuses VFM semantics at the midpoint (block depth/2-1) by
calling `proj_fusion(cat([x, semantics], -1))` and upsampling the token grid
to stage-2 resolution. Right after that fusion we additionally inject the
sparse-prompt tokens through `PromptGate`. The remaining stage-2 blocks then
attend over the gated tokens.
Only the prompt encoder + gate are new parameters; everything else is
identical to the pretrained DiT and can stay frozen.
"""
from __future__ import annotations
from typing import Optional
import torch
import torch.nn.functional as F
from ppd.models.dit import DiT
from ppd.lpd.prompt_encoder import SparsePromptEncoder
from ppd.lpd.prompt_gate import PromptGate
from ppd.lpd.uncertainty_modulation import modulate_density
class LPDDiT(DiT):
    """DiT + sparse-prompt fusion at the midpoint.

    `forward` accepts the original (x, semantics, timestep) plus optional
    sparse_depth + sparse_mask. When sparse inputs are None, the prompt
    encoder and gate are skipped entirely and only the parent DiT's modules
    run (so a checkpoint trained as PPD still runs).
    """

    def __init__(
        self,
        in_channels: int = 4,
        out_channels: int = 1,
        hidden_size: int = 1024,
        depth: int = 24,
        num_heads: int = 16,
        patch_size: int = 8,
        mlp_ratio: float = 4.0,
        prompt_scales: tuple[int, ...] = (4, 8, 16, 32),
        prompt_hidden: int = 128,
    ):
        """Build the parent DiT, then attach the prompt encoder and gate.

        The first seven arguments are forwarded unchanged to `DiT.__init__`.

        Args:
            prompt_scales: Scales handed to `SparsePromptEncoder` (stored as a
                tuple on `self.prompt_scales`).
            prompt_hidden: `hidden` width handed to `SparsePromptEncoder`.
        """
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            hidden_size=hidden_size,
            depth=depth,
            num_heads=num_heads,
            patch_size=patch_size,
            mlp_ratio=mlp_ratio,
        )
        # `forward` embeds x on a stage-1 grid of (H/(2p), W/(2p)) tokens and
        # doubles it at the midpoint, so stage 2 runs on roughly (H/p, W/p)
        # tokens (p = patch_size). The encoder is asked for that grid via
        # `out_grid_div=patch_size` -- presumably its output has H/p * W/p
        # tokens; `forward` resamples if it does not.
        self.prompt_scales = tuple(prompt_scales)
        self.sparse_prompt_encoder = SparsePromptEncoder(
            scales=self.prompt_scales,
            embed_dim=hidden_size,
            out_grid_div=patch_size,
            hidden=prompt_hidden,
        )
        self.prompt_gate = PromptGate(
            embed_dim=hidden_size, timestep_dim=hidden_size
        )

    @staticmethod
    def _infer_token_grid(
        num_tokens: int, target_h: int, target_w: int
    ) -> tuple[int, int]:
        """Factor `num_tokens` into (rows, cols) best matching target_h:target_w.

        A flattened token sequence does not carry its 2-D shape, so the shape
        is recovered as the factor pair whose aspect ratio is closest to the
        target grid's. For a square target this yields the most-square factor
        pair (rows <= cols).

        Raises:
            ValueError: If any argument is not positive.
        """
        if num_tokens <= 0 or target_h <= 0 or target_w <= 0:
            raise ValueError(
                "cannot infer prompt grid: "
                f"num_tokens={num_tokens}, target=({target_h}, {target_w})"
            )
        best_hw = (1, num_tokens)
        best_err = None
        rows = 1
        while rows * rows <= num_tokens:
            if num_tokens % rows == 0:
                cols = num_tokens // rows
                # Try both orientations so portrait and landscape both work.
                for gh, gw in ((rows, cols), (cols, rows)):
                    # Ratio (>= 1) between the candidate and target aspect
                    # ratios, via cross-multiplication; 1.0 is an exact match.
                    a = gh * target_w
                    b = gw * target_h
                    err = max(a, b) / min(a, b)
                    if best_err is None or err < best_err:
                        best_err = err
                        best_hw = (gh, gw)
            rows += 1
        return best_hw

    @staticmethod
    def _resample_tokens(
        tokens: torch.Tensor,
        src_hw: tuple[int, int],
        dst_hw: tuple[int, int],
    ) -> torch.Tensor:
        """Bilinearly resize (N, src_h*src_w, C) tokens to (N, dst_h*dst_w, C)."""
        n, _, channels = tokens.shape
        grid = tokens.transpose(1, 2).reshape(n, channels, src_hw[0], src_hw[1])
        grid = F.interpolate(
            grid, size=dst_hw, mode="bilinear", align_corners=False
        )
        return grid.flatten(2).transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        semantics: torch.Tensor,
        timestep: torch.Tensor,
        *,
        sparse_depth: Optional[torch.Tensor] = None,
        sparse_mask: Optional[torch.Tensor] = None,
        kalman_variance: Optional[torch.Tensor] = None,
        dropout: float = 0.1,
    ) -> torch.Tensor:
        """Run the DiT, gating sparse-prompt tokens in at the midpoint.

        Args:
            x: Input of shape (N, C, H, W).
            semantics: Semantic tokens; L2-normalized on the last dim and
                concatenated with the stage-1 tokens before `proj_fusion`.
            timestep: Timestep tensor; a 0-dim tensor is promoted to shape (1,).
            sparse_depth: Optional sparse depth passed to the prompt encoder.
            sparse_mask: Optional validity mask passed to the prompt encoder.
                The prompt path runs only when BOTH sparse_depth and
                sparse_mask are given; otherwise it is skipped.
            kalman_variance: Optional variance used by `modulate_density` to
                adjust the density tokens before gating.
            dropout: Currently unused by this method; kept for interface
                compatibility.

        Returns:
            The result of `unpatchify(final_layer(...), height=H, width=W)`.
        """
        N, _, H, W = x.shape
        if timestep.ndim == 0:
            timestep = timestep[None]

        # Stage-1 tokens live on an (H/(2p), W/(2p)) grid; the midpoint reshape
        # doubles each side. Derive both grids from patch_size so the RoPE
        # positions always agree with the reshape below.
        p = self.patch_size * 2
        h1, w1 = H // p, W // p
        h2, w2 = h1 * 2, w1 * 2

        pos0 = pos1 = None
        if self.rope is not None:
            pos0 = self.position_getter(N, h1, w1, device=x.device)
            pos1 = self.position_getter(N, h2, w2, device=x.device)

        x = self.x_embedder(x)
        t = self.t_embedder(timestep)  # (N, D)

        # Pre-compute prompt tokens if sparse inputs are provided.
        prompt_tokens = density_tokens = None
        if sparse_depth is not None and sparse_mask is not None:
            prompt_tokens, density_tokens = self.sparse_prompt_encoder(
                sparse_depth, sparse_mask
            )
            if kalman_variance is not None:
                density_tokens = modulate_density(
                    density_tokens, kalman_variance
                )

        mid = self.depth // 2
        for i, block in enumerate(self.blocks):
            x = block(x, t, pos0 if i < mid else pos1)

            if i == mid - 1:
                # Stage-1 -> stage-2 transition: semantics fusion + reshape.
                semantics_norm = F.normalize(semantics, dim=-1)
                x = self.proj_fusion(torch.cat([x, semantics_norm], dim=-1))
                # Each fused token carries a 2x2 patch of D-dim sub-tokens;
                # unfold them into a grid with twice the side length.
                D = x.shape[-1] // 4
                x = x.reshape(N, h1, w1, 2, 2, D)
                x = torch.einsum("nhwpqc->nchpwq", x)
                x = x.reshape(N, D, h2, w2)
                x = x.flatten(2).transpose(1, 2)

                # Apply the prompt gate on the stage-2 grid, before the
                # stage-2 blocks run.
                if prompt_tokens is not None:
                    if prompt_tokens.shape[1] != x.shape[1]:
                        # Token counts differ: recover the prompt grid shape
                        # from its aspect ratio (it need not be square) and
                        # resample both prompt streams onto the stage-2 grid.
                        src_hw = self._infer_token_grid(
                            prompt_tokens.shape[1], h2, w2
                        )
                        prompt_tokens = self._resample_tokens(
                            prompt_tokens, src_hw, (h2, w2)
                        )
                        density_tokens = self._resample_tokens(
                            density_tokens, src_hw, (h2, w2)
                        )
                    x = self.prompt_gate(x, prompt_tokens, density_tokens, t)

        x = self.final_layer(x, t)
        x = self.unpatchify(x, height=H, width=W)
        return x

    # ------------------------------------------------------------------
    # Helpers for partial-loading from a vanilla PPD checkpoint
    # ------------------------------------------------------------------
    def freeze_backbone(self) -> None:
        """Freeze every parameter that came from the parent DiT.

        Only the prompt encoder + gate stay trainable, matching paper §3.6:
        all extensions are inference-time mechanisms or lightweight prompt
        modules training fewer than 1% of total parameters.
        """
        # Freeze everything first, then re-enable the prompt branches.
        self.requires_grad_(False)
        self.sparse_prompt_encoder.requires_grad_(True)
        self.prompt_gate.requires_grad_(True)

    def num_trainable_params(self) -> int:
        """Return the number of scalar parameters with `requires_grad=True`."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def num_total_params(self) -> int:
        """Return the total number of scalar parameters in the model."""
        return sum(p.numel() for p in self.parameters())