| """ |
| Noise-level-conditioned prompt gate. |
| |
| Given semantic tokens (from SP-DiT's stage-1 fusion), sparse-prompt tokens |
| (from `SparsePromptEncoder`), per-token density and a timestep embedding, |
| produce gated joint tokens: |
| |
| s_joint = s_sem + g(p, ρ, t) ⊙ m(s_sem, p, ρ, t) |
| |
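`m` is a mixer MLP, `g` is a sigmoid gate. The final linear layer of each is
zero-initialized, so at the start of training `m` outputs exactly zero (and
`g` outputs sigmoid(0) = 0.5). The joint tokens therefore equal `s_sem`, and
the model starts identical to the pretrained PPD.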
| |
| The timestep injection (paper §3.1) lets the gate learn to attend to the |
| semantic prior at high noise and sharpen with sparse depth at low noise. |
| """ |
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
def _zero_linear(in_features: int, out_features: int, bias: bool = True) -> nn.Linear:
    """Return an ``nn.Linear`` whose parameters all start at exactly zero.

    The weight, and the bias when one is requested, are filled with zeros, so
    a freshly built layer maps every input to an all-zero output.
    """
    linear = nn.Linear(in_features, out_features, bias=bias)
    # nn.Linear only creates a bias parameter when ``bias`` is truthy.
    params = [linear.weight] if linear.bias is None else [linear.weight, linear.bias]
    for param in params:
        nn.init.zeros_(param)
    return linear
|
|
|
|
class PromptGate(nn.Module):
    """Noise-level-conditioned prompt gating layer.

    Computes ``s_sem + g(p, rho, t) * m(s_sem, p, rho, t)``.

    - ``m`` (the mixer) is a two-layer GELU MLP.
    - ``g`` (the gate) is a two-layer GELU MLP followed by a sigmoid.
    - The per-token density ``rho`` is part of the gate's input, so
      high-confidence (densely-observed) tokens can dominate when useful.
    - Density and the timestep embedding are broadcast across the token axis.

    The final linear layer of both MLPs is zero-initialized. At construction
    the mixer therefore outputs exactly zero, which makes the layer an
    identity on ``s_sem``. The gate outputs ``sigmoid(0) = 0.5`` everywhere.

    Args:
        embed_dim: Feature width D of ``s_sem``, ``p`` and the output.
        timestep_dim: Width D_t of the timestep embedding.
        density_dim: Width of the per-token density feature.
        hidden: Hidden width of both the mixer and the gate MLPs.
    """

    def __init__(
        self,
        embed_dim: int = 1024,
        timestep_dim: int = 1024,
        density_dim: int = 1,
        hidden: int = 1024,
    ):
        super().__init__()

        # The mixer sees semantic tokens, prompt tokens, density and timestep.
        in_features = embed_dim * 2 + density_dim + timestep_dim
        self.mixer = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.GELU(),
            _zero_linear(hidden, embed_dim),  # zero delta at init -> identity
        )

        # The gate deliberately omits s_sem: it is g(p, rho, t).
        gate_in = embed_dim + density_dim + timestep_dim
        self.gate = nn.Sequential(
            nn.Linear(gate_in, hidden),
            nn.GELU(),
            _zero_linear(hidden, embed_dim),  # sigmoid(0) = 0.5 at init
            nn.Sigmoid(),
        )

    def forward(
        self,
        s_sem: torch.Tensor,
        p: torch.Tensor,
        density: torch.Tensor,
        t_emb: torch.Tensor,
    ) -> torch.Tensor:
        """Fuse sparse-prompt tokens into semantic tokens via the gated mixer.

        Args:
            s_sem: Semantic tokens, shape (B, T, D).
            p: Sparse-prompt tokens, shape (B, T, D).
            density: Per-token density, shape (B, T, density_dim).
                - A rank-2 (B, T) tensor is also accepted when
                  ``density_dim == 1``.
                - A size-1 batch or token axis is broadcast.
            t_emb: Timestep embedding.
                - Shape (B, D_t) is broadcast over all tokens.
                - A per-token (B, T, D_t) or (B, 1, D_t) tensor is also
                  accepted.

        Returns:
            Joint tokens of shape (B, T, D).
        """
        B, T, _ = s_sem.shape
        # Cast the conditioning inputs to the token dtype. Otherwise a
        # density/timestep tensor of another dtype (e.g. float32 next to
        # bf16 tokens) would change the dtype torch.cat promotes to.
        dtype = s_sem.dtype

        if density.dim() == s_sem.dim() - 1:
            density = density.unsqueeze(-1)  # (B, T) -> (B, T, 1)
        density = density.to(dtype).expand(B, T, -1)

        if t_emb.dim() == 2:
            t_emb = t_emb.unsqueeze(1)  # (B, D_t) -> (B, 1, D_t)
        t_broad = t_emb.to(dtype).expand(B, T, -1)

        mixer_in = torch.cat([s_sem, p, density, t_broad], dim=-1)
        gate_in = torch.cat([p, density, t_broad], dim=-1)
        delta = self.mixer(mixer_in)
        g = self.gate(gate_in)
        return s_sem + g * delta
|
|