""" LPD-DiT: PPD's DiT augmented with a sparse-LiDAR prompt path. The base DiT already fuses VFM semantics at the midpoint (block depth/2-1) by calling `proj_fusion(cat([x, semantics], -1))` and upsampling the token grid to stage-2 resolution. Right after that fusion we additionally inject the sparse-prompt tokens through `PromptGate`. The remaining stage-2 blocks then attend over the gated tokens. Only the prompt encoder + gate are new parameters; everything else is identical to the pretrained DiT and can stay frozen. """ from __future__ import annotations from typing import Optional import torch import torch.nn.functional as F from ppd.models.dit import DiT from ppd.lpd.prompt_encoder import SparsePromptEncoder from ppd.lpd.prompt_gate import PromptGate from ppd.lpd.uncertainty_modulation import modulate_density class LPDDiT(DiT): """DiT + sparse-prompt fusion at the midpoint. `forward` accepts the original (x, semantics, timestep) plus optional sparse_depth + sparse_mask. When sparse inputs are None, behavior is identical to the parent DiT (so a checkpoint trained as PPD still runs). """ def __init__( self, in_channels: int = 4, out_channels: int = 1, hidden_size: int = 1024, depth: int = 24, num_heads: int = 16, patch_size: int = 8, mlp_ratio: float = 4.0, prompt_scales: tuple[int, ...] = (4, 8, 16, 32), prompt_hidden: int = 128, ): super().__init__( in_channels=in_channels, out_channels=out_channels, hidden_size=hidden_size, depth=depth, num_heads=num_heads, patch_size=patch_size, mlp_ratio=mlp_ratio, ) # Prompt encoder produces tokens at the stage-2 grid (H/(2p), W/(2p)), # which after the parent DiT's stage-1 → stage-2 reshape equals (H/p, W/p) # for tokens. PPD's stage-2 grid has spatial resolution H/p (p=8 default). self.prompt_scales = tuple(prompt_scales) self.sparse_prompt_encoder = SparsePromptEncoder( scales=self.prompt_scales, embed_dim=hidden_size, out_grid_div=patch_size, hidden=prompt_hidden, ) self.prompt_gate = PromptGate( embed_dim=hidden_size, timestep_dim=hidden_size ) def forward( self, x: torch.Tensor, semantics: torch.Tensor, timestep: torch.Tensor, *, sparse_depth: Optional[torch.Tensor] = None, sparse_mask: Optional[torch.Tensor] = None, kalman_variance: Optional[torch.Tensor] = None, dropout: float = 0.1, ) -> torch.Tensor: N, C, H, W = x.shape if timestep.ndim == 0: timestep = timestep[None] pos0 = pos1 = None if self.rope is not None: pos0 = self.position_getter(N, H // 16, W // 16, device=x.device) pos1 = self.position_getter(N, H // 8, W // 8, device=x.device) x = self.x_embedder(x) t = self.t_embedder(timestep) # (N, D) # Pre-compute prompt tokens at stage-2 grid if sparse inputs provided. prompt_tokens = density_tokens = None if sparse_depth is not None and sparse_mask is not None: prompt_tokens, density_tokens = self.sparse_prompt_encoder( sparse_depth, sparse_mask ) if kalman_variance is not None: density_tokens = modulate_density(density_tokens, kalman_variance) for i, block in enumerate(self.blocks): if i < (self.depth // 2): x = block(x, t, pos0) else: x = block(x, t, pos1) if i == (self.depth // 2) - 1: # Stage-1 → Stage-2 transition: PPD's semantics fusion + reshape. semantics_norm = F.normalize(semantics, dim=-1) x = self.proj_fusion(torch.cat([x, semantics_norm], dim=-1)) p = self.patch_size * 2 D = x.shape[-1] // 4 x = x.reshape(N, H // p, W // p, 2, 2, D) x = torch.einsum("nhwpqc->nchpwq", x) x = x.reshape(N, D, (H // p) * 2, (W // p) * 2) x = x.flatten(2).transpose(1, 2) # New: apply prompt gate at the stage-2 grid, before stage-2 blocks. if prompt_tokens is not None: h2, w2 = (H // p) * 2, (W // p) * 2 if prompt_tokens.shape[1] != x.shape[1]: # Resample prompt tokens to match stage-2 grid in case of mismatch. prompt_h = int(prompt_tokens.shape[1] ** 0.5) prompt_w = prompt_tokens.shape[1] // max(prompt_h, 1) prompt_tokens = F.interpolate( prompt_tokens.transpose(1, 2).reshape( N, D, prompt_h, prompt_w ), size=(h2, w2), mode="bilinear", align_corners=False, ).flatten(2).transpose(1, 2) density_tokens = F.interpolate( density_tokens.transpose(1, 2).reshape( N, 1, prompt_h, prompt_w ), size=(h2, w2), mode="bilinear", align_corners=False, ).flatten(2).transpose(1, 2) x = self.prompt_gate(x, prompt_tokens, density_tokens, t) x = self.final_layer(x, t) x = self.unpatchify(x, height=H, width=W) return x # ------------------------------------------------------------------ # Helpers for partial-loading from a vanilla PPD checkpoint # ------------------------------------------------------------------ def freeze_backbone(self) -> None: """Freeze every parameter that came from the parent DiT. Only the prompt encoder + gate stay trainable, matching paper §3.6: all extensions are inference-time mechanisms or lightweight prompt modules training fewer than 1% of total parameters. """ # Freeze everything first, then re-enable prompt branches for p in self.parameters(): p.requires_grad = False for p in self.sparse_prompt_encoder.parameters(): p.requires_grad = True for p in self.prompt_gate.parameters(): p.requires_grad = True def num_trainable_params(self) -> int: return sum(p.numel() for p in self.parameters() if p.requires_grad) def num_total_params(self) -> int: return sum(p.numel() for p in self.parameters())