LiDAR-Perfect-Depth / code /ppd /lpd /prompt_gate.py
chenming-wu's picture
code
436b829 verified
"""
Noise-level-conditioned prompt gate.
Given semantic tokens (from SP-DiT's stage-1 fusion), sparse-prompt tokens
(from `SparsePromptEncoder`), per-token density and a timestep embedding,
produce gated joint tokens:
s_joint = s_sem + g(p, ρ, t) ⊙ m(s_sem, p, ρ, t)
`m` is a mixer MLP, `g` is a sigmoid gate. The output projection of each is
zero-initialized, so `m` outputs zeros at initialization (with `g` at 0.5) and
the model starts identical to the pretrained PPD.
The timestep injection (paper §3.1) lets the gate learn to attend to the
semantic prior at high noise and sharpen with sparse depth at low noise.
"""
from __future__ import annotations
import torch
import torch.nn as nn
def _zero_linear(in_features: int, out_features: int, bias: bool = True) -> nn.Linear:
    """Build an ``nn.Linear`` whose parameters all start at zero.

    The layer therefore maps every input to zeros until it is trained.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the layer carries a (zeroed) bias term.

    Returns:
        The freshly constructed, zero-filled linear layer.
    """
    projection = nn.Linear(in_features, out_features, bias=bias)
    # parameters() yields the weight and, only when ``bias`` is set, the bias.
    for param in projection.parameters():
        nn.init.zeros_(param)
    return projection
class PromptGate(nn.Module):
    """Noise-level-conditioned prompt gating layer.

    Computes ``s_joint = s_sem + g(p, ρ, t) ⊙ m(s_sem, p, ρ, t)`` where ``m``
    is a two-layer GELU mixer MLP and ``g`` is a two-layer GELU MLP followed
    by a sigmoid.

    The per-token density ρ is folded into the gate's input so high-confidence
    (densely-observed) tokens can dominate when useful. Density is supplied
    per token; only the timestep embedding is broadcast across the token axis.

    The output projection of each branch is zero-initialized, so right after
    construction ``m`` outputs zeros (and ``g`` outputs 0.5) and the layer
    returns ``s_sem`` unchanged.

    Args:
        embed_dim: Feature width of the semantic and sparse-prompt tokens.
        timestep_dim: Feature width of the timestep embedding.
        density_dim: Feature width of the per-token density input.
        hidden: Hidden width of both the mixer and the gate MLP.
    """

    def __init__(
        self,
        embed_dim: int = 1024,
        timestep_dim: int = 1024,
        density_dim: int = 1,
        hidden: int = 1024,
    ):
        super().__init__()
        # input: [s_sem, p, ρ, t] concatenated along feature axis.
        in_features = embed_dim * 2 + density_dim + timestep_dim
        self.mixer = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.GELU(),
            _zero_linear(hidden, embed_dim),
        )
        # gate input: [p, ρ, t]
        gate_in = embed_dim + density_dim + timestep_dim
        self.gate = nn.Sequential(
            nn.Linear(gate_in, hidden),
            nn.GELU(),
            _zero_linear(hidden, embed_dim),
            nn.Sigmoid(),
        )

    def forward(
        self,
        s_sem: torch.Tensor,
        p: torch.Tensor,
        density: torch.Tensor,
        t_emb: torch.Tensor,
    ) -> torch.Tensor:
        """Fuse sparse-prompt tokens into the semantic tokens.

        Args:
            s_sem: Semantic tokens, shape (B, T, D).
            p: Sparse-prompt tokens, shape (B, T, D).
            density: Per-token density, shape (B, T, density_dim). A (B, T)
                tensor is also accepted and treated as (B, T, 1).
            t_emb: Timestep embedding, either (B, D_t), which is broadcast
                over the token axis, or already per-token (B, T, D_t).

        Returns:
            Joint tokens, shape (B, T, D).
        """
        B, T, _ = s_sem.shape
        if density.dim() == 2:
            # Scalar-per-token density given without its feature axis.
            density = density.unsqueeze(-1)
        if t_emb.dim() == 2:
            # One embedding per sample: add a token axis to broadcast over.
            t_emb = t_emb[:, None, :]
        # expand() creates a view; it is a no-op for an already (B, T, D_t) input.
        t_broad = t_emb.expand(B, T, -1)
        mixer_in = torch.cat([s_sem, p, density, t_broad], dim=-1)
        gate_in = torch.cat([p, density, t_broad], dim=-1)
        delta = self.mixer(mixer_in)
        g = self.gate(gate_in)
        return s_sem + g * delta