"""
LPD-DiT: PPD's DiT augmented with a sparse-LiDAR prompt path.

The base DiT already fuses VFM semantics at the midpoint (right after block
`depth // 2 - 1`) by calling
`proj_fusion(cat([x, normalize(semantics)], -1))` and rearranging each fused
token into a 2x2 sub-patch, which doubles the token grid to stage-2
resolution. Right after that fusion we additionally inject the sparse-prompt
tokens through `PromptGate`. The remaining stage-2 blocks then attend over the
gated tokens.

Only the prompt encoder and gate are new parameters; everything else is
identical to the pretrained DiT and can stay frozen.
"""
| from __future__ import annotations |
|
|
| from typing import Optional |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| from ppd.models.dit import DiT |
| from ppd.lpd.prompt_encoder import SparsePromptEncoder |
| from ppd.lpd.prompt_gate import PromptGate |
| from ppd.lpd.uncertainty_modulation import modulate_density |
|
|
|
|
class LPDDiT(DiT):
    """DiT + sparse-prompt fusion at the midpoint.

    `forward` accepts the original (x, semantics, timestep) plus optional
    `sparse_depth` + `sparse_mask`. When either sparse input is None the
    prompt path is skipped and behavior is identical to the parent DiT (so a
    checkpoint trained as PPD still runs).
    """

    def __init__(
        self,
        in_channels: int = 4,
        out_channels: int = 1,
        hidden_size: int = 1024,
        depth: int = 24,
        num_heads: int = 16,
        patch_size: int = 8,
        mlp_ratio: float = 4.0,
        prompt_scales: tuple[int, ...] = (4, 8, 16, 32),
        prompt_hidden: int = 128,
    ):
        """Build the parent DiT, then attach the new prompt modules.

        Args:
            in_channels, out_channels, hidden_size, depth, num_heads,
            patch_size, mlp_ratio: forwarded unchanged to `DiT.__init__`.
            prompt_scales: scales handed to `SparsePromptEncoder` (stored as
                a tuple on `self.prompt_scales`).
            prompt_hidden: `hidden` width passed to `SparsePromptEncoder`.
        """
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            hidden_size=hidden_size,
            depth=depth,
            num_heads=num_heads,
            patch_size=patch_size,
            mlp_ratio=mlp_ratio,
        )

        self.prompt_scales = tuple(prompt_scales)
        # embed_dim=hidden_size presumably makes prompt tokens the same width
        # as the DiT tokens they are gated into; out_grid_div=patch_size
        # presumably targets the stage-2 token grid -- confirm against
        # SparsePromptEncoder.
        self.sparse_prompt_encoder = SparsePromptEncoder(
            scales=self.prompt_scales,
            embed_dim=hidden_size,
            out_grid_div=patch_size,
            hidden=prompt_hidden,
        )
        self.prompt_gate = PromptGate(
            embed_dim=hidden_size, timestep_dim=hidden_size
        )

    def forward(
        self,
        x: torch.Tensor,
        semantics: torch.Tensor,
        timestep: torch.Tensor,
        *,
        sparse_depth: Optional[torch.Tensor] = None,
        sparse_mask: Optional[torch.Tensor] = None,
        kalman_variance: Optional[torch.Tensor] = None,
        dropout: float = 0.1,
    ) -> torch.Tensor:
        """Run the DiT, injecting sparse-prompt tokens after the midpoint fusion.

        Args:
            x: input tensor, unpacked as (N, C, H, W).
            semantics: VFM features; L2-normalized on the last dim and
                concatenated onto the stage-1 tokens at the midpoint.
            timestep: diffusion timestep; a 0-d tensor is promoted to shape (1,).
            sparse_depth, sparse_mask: sparse prompt inputs. The prompt path
                runs only when BOTH are given; if either is None it is skipped.
            kalman_variance: optional; when given (and the prompt path runs)
                the density tokens are passed through `modulate_density`.
            dropout: currently not read by this method; kept so existing
                call sites that pass it keep working.

        Returns:
            `unpatchify(final_layer(x, t))` at the input's (H, W).

        Raises:
            ValueError: if the prompt tokens must be resized but their count
                cannot be factored into a grid matching the aspect ratio of
                `sparse_depth`.
        """
        N, _, H, W = x.shape
        if timestep.ndim == 0:
            timestep = timestep[None]

        # Stage-1 tokens live on a grid at stride 2*patch_size; the midpoint
        # fusion turns every token into a 2x2 sub-patch, doubling the grid for
        # stage 2. Both grids are derived from the same numbers the fusion
        # reshape uses, so the RoPE positions always match the token counts.
        stride = self.patch_size * 2
        h1, w1 = H // stride, W // stride
        h2, w2 = h1 * 2, w1 * 2
        mid = self.depth // 2

        pos0 = pos1 = None
        if self.rope is not None:
            pos0 = self.position_getter(N, h1, w1, device=x.device)
            pos1 = self.position_getter(N, h2, w2, device=x.device)

        x = self.x_embedder(x)
        t = self.t_embedder(timestep)

        # Encode the sparse prompt once up front; it is consumed at the midpoint.
        prompt_tokens = density_tokens = None
        if sparse_depth is not None and sparse_mask is not None:
            prompt_tokens, density_tokens = self.sparse_prompt_encoder(
                sparse_depth, sparse_mask
            )
            if kalman_variance is not None:
                density_tokens = modulate_density(density_tokens, kalman_variance)

        for i, block in enumerate(self.blocks):
            x = block(x, t, pos0 if i < mid else pos1)

            if i == mid - 1:
                x = self._fuse_semantics(x, semantics, N, h1, w1)

                if prompt_tokens is not None:
                    if prompt_tokens.shape[1] != x.shape[1]:
                        # Resample the prompt onto the stage-2 grid. The source
                        # grid shape comes from the sparse input's aspect
                        # ratio; assumes sparse_depth's last two dims are
                        # spatial (H_s, W_s) -- TODO confirm.
                        src_hw = self._prompt_grid_hw(
                            prompt_tokens.shape[1],
                            sparse_depth.shape[-2],
                            sparse_depth.shape[-1],
                        )
                        prompt_tokens = self._resize_tokens(
                            prompt_tokens, src_hw, (h2, w2)
                        )
                        density_tokens = self._resize_tokens(
                            density_tokens, src_hw, (h2, w2)
                        )
                    x = self.prompt_gate(x, prompt_tokens, density_tokens, t)

        x = self.final_layer(x, t)
        return self.unpatchify(x, height=H, width=W)

    def _fuse_semantics(
        self,
        x: torch.Tensor,
        semantics: torch.Tensor,
        batch: int,
        h1: int,
        w1: int,
    ) -> torch.Tensor:
        """Fuse normalized semantics into stage-1 tokens and unfold to stage 2.

        `proj_fusion`'s output channels are split into a 2x2 sub-patch of
        width D = channels // 4, so (batch, h1*w1, 4D) becomes
        (batch, (2*h1)*(2*w1), D).
        """
        fused = self.proj_fusion(
            torch.cat([x, F.normalize(semantics, dim=-1)], dim=-1)
        )
        d = fused.shape[-1] // 4
        grid = fused.reshape(batch, h1, w1, 2, 2, d)
        grid = torch.einsum("nhwpqc->nchpwq", grid)
        grid = grid.reshape(batch, d, h1 * 2, w1 * 2)
        return grid.flatten(2).transpose(1, 2)

    @staticmethod
    def _prompt_grid_hw(
        num_tokens: int, height: int, width: int
    ) -> tuple[int, int]:
        """Factor `num_tokens` into an (h, w) grid with aspect ~ height:width.

        A square-root guess is not used because it mis-factors non-square
        grids (e.g. 12x48 = 576 tokens would be read as 24x24).

        Raises:
            ValueError: on non-positive inputs, or when no grid near the
                requested aspect ratio has exactly `num_tokens` cells.
        """
        if num_tokens <= 0 or height <= 0 or width <= 0:
            raise ValueError(
                f"invalid prompt grid request: num_tokens={num_tokens}, "
                f"height={height}, width={width}"
            )
        estimate = max(1, round((num_tokens * height / width) ** 0.5))
        # Allow +-1 slack for floor/ceil differences in the encoder's output grid.
        for grid_h in (estimate, estimate - 1, estimate + 1):
            if grid_h >= 1 and num_tokens % grid_h == 0:
                return grid_h, num_tokens // grid_h
        raise ValueError(
            f"cannot factor {num_tokens} prompt tokens into a grid with "
            f"aspect ratio {height}:{width}"
        )

    @staticmethod
    def _resize_tokens(
        tokens: torch.Tensor,
        src_hw: tuple[int, int],
        dst_hw: tuple[int, int],
    ) -> torch.Tensor:
        """Bilinearly resample (N, L, C) tokens laid out on `src_hw` to `dst_hw`."""
        n, _, c = tokens.shape
        grid = tokens.transpose(1, 2).reshape(n, c, *src_hw)
        grid = F.interpolate(
            grid, size=dst_hw, mode="bilinear", align_corners=False
        )
        return grid.flatten(2).transpose(1, 2)

    def freeze_backbone(self) -> None:
        """Freeze every parameter that came from the parent DiT.

        Only the prompt encoder + gate stay trainable, matching paper §3.6:
        all extensions are inference-time mechanisms or lightweight prompt
        modules training fewer than 1% of total parameters.
        """
        self.requires_grad_(False)
        self.sparse_prompt_encoder.requires_grad_(True)
        self.prompt_gate.requires_grad_(True)

    def num_trainable_params(self) -> int:
        """Return the number of scalar parameters with `requires_grad=True`."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def num_total_params(self) -> int:
        """Return the total number of scalar parameters."""
        return sum(p.numel() for p in self.parameters())
|
|