Spaces: Running
# Adapted from: https://github.com/facebookresearch/DiT
import torch
from torch import nn
| class AdaLNZero(nn.Module): | |
| """ | |
| Adaptive Layer Normalization Zero (AdaLNZero) module. | |
| Combines LayerNorm with adaptive conditioning to produce shift, scale, and gate values. | |
| The gate is used to scale features before residual connection. | |
| Args: | |
| dim: Feature dimension | |
| condition_dim: Conditioning dimension | |
| eps: LayerNorm epsilon | |
| return_gate: If True, returns gate value for scaling. | |
| """ | |
| def __init__( | |
| self, | |
| dim: int, | |
| condition_dim: int, | |
| eps: float = 1e-5, | |
| return_gate: bool = True, | |
| ): | |
| super().__init__() | |
| self.dim = dim | |
| self.condition_dim = condition_dim | |
| self.return_gate = return_gate | |
| # LayerNorm without learnable parameters | |
| self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) | |
| # Conditioning network: condition -> shift, scale, gate | |
| output_dim = 3 * dim if return_gate else 2 * dim | |
| self.condition_proj = nn.Sequential( | |
| nn.SiLU(), | |
| nn.Linear(condition_dim, output_dim), | |
| ) | |
| # Initialize to zero for stable training | |
| nn.init.zeros_(self.condition_proj[1].weight) | |
| nn.init.zeros_(self.condition_proj[1].bias) | |
| def forward(self, x: torch.Tensor, condition: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor] | None: | |
| """ | |
| Args: | |
| x: Input tensor of shape (B, L, dim) | |
| condition: Conditioning tensor of shape (B, L, condition_dim) or (B, 1, condition_dim) | |
| Returns: | |
| modulated_x: Normalized and modulated features | |
| gate: Gate values for scaling (None if return_gate=False) | |
| """ | |
| x_norm = self.norm(x) | |
| condition_params = self.condition_proj(condition) | |
| if self.return_gate: | |
| shift, scale, gate = condition_params.chunk(3, dim=-1) | |
| else: | |
| shift, scale = condition_params.chunk(2, dim=-1) | |
| gate = None | |
| modulated_x = x_norm * (1 + scale) + shift | |
| return modulated_x, gate | |