""" Noise-level-conditioned prompt gate. Given semantic tokens (from SP-DiT's stage-1 fusion), sparse-prompt tokens (from `SparsePromptEncoder`), per-token density and a timestep embedding, produce gated joint tokens: s_joint = s_sem + g(p, ρ, t) ⊙ m(s_sem, p, ρ, t) `m` is a mixer MLP, `g` is a sigmoid gate. Both are zero-initialized so the model starts identical to the pretrained PPD. The timestep injection (paper §3.1) lets the gate learn to attend to the semantic prior at high noise and sharpen with sparse depth at low noise. """ from __future__ import annotations import torch import torch.nn as nn def _zero_linear(in_features: int, out_features: int, bias: bool = True) -> nn.Linear: layer = nn.Linear(in_features, out_features, bias=bias) nn.init.zeros_(layer.weight) if bias: nn.init.zeros_(layer.bias) return layer class PromptGate(nn.Module): """Noise-level-conditioned prompt gating layer. The per-token density ρ is folded into the gate's input so high-confidence (densely-observed) tokens can dominate when useful. Both density and the timestep embedding are broadcast across the token axis. """ def __init__( self, embed_dim: int = 1024, timestep_dim: int = 1024, density_dim: int = 1, hidden: int = 1024, ): super().__init__() # input: [s_sem, p, ρ, t] concatenated along feature axis. in_features = embed_dim * 2 + density_dim + timestep_dim self.mixer = nn.Sequential( nn.Linear(in_features, hidden), nn.GELU(), _zero_linear(hidden, embed_dim), ) # gate input: [p, ρ, t] gate_in = embed_dim + density_dim + timestep_dim self.gate = nn.Sequential( nn.Linear(gate_in, hidden), nn.GELU(), _zero_linear(hidden, embed_dim), nn.Sigmoid(), ) def forward( self, s_sem: torch.Tensor, p: torch.Tensor, density: torch.Tensor, t_emb: torch.Tensor, ) -> torch.Tensor: """Inputs: s_sem (B,T,D), p (B,T,D), density (B,T,1), t_emb (B,D_t). Returns joint tokens (B,T,D). """ B, T, _ = s_sem.shape t_broad = t_emb[:, None, :].expand(B, T, -1) mixer_in = torch.cat([s_sem, p, density, t_broad], dim=-1) gate_in = torch.cat([p, density, t_broad], dim=-1) delta = self.mixer(mixer_in) g = self.gate(gate_in) return s_sem + g * delta