import torch
from torch import nn


class AdaLNZero(nn.Module):
| """ |
| Adaptive Layer Normalization Zero (AdaLNZero) module. |
| |
| Combines LayerNorm with adaptive conditioning to produce shift, scale, and gate values. |
| The gate is used to scale features before residual connection. |
| |
| Args: |
| dim: Feature dimension |
| condition_dim: Conditioning dimension |
| eps: LayerNorm epsilon |
| return_gate: If True, returns gate value for scaling. |
| """ |

    def __init__(
        self,
        dim: int,
        condition_dim: int,
        eps: float = 1e-5,
        return_gate: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.condition_dim = condition_dim
        self.return_gate = return_gate

        # LayerNorm without learnable affine parameters; shift and scale come
        # from the conditioning signal instead.
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)

        # Project the condition to shift and scale, plus a gate when requested.
        output_dim = 3 * dim if return_gate else 2 * dim
        self.condition_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(condition_dim, output_dim),
        )

        # Zero-initialize the projection so shift, scale, and gate all start
        # at zero: the module begins as a plain LayerNorm with a zero gate,
        # which is the "Zero" in AdaLN-Zero.
        nn.init.zeros_(self.condition_proj[1].weight)
        nn.init.zeros_(self.condition_proj[1].bias)

    def forward(
        self, x: torch.Tensor, condition: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Args:
            x: Input tensor of shape (B, L, dim).
            condition: Conditioning tensor of shape (B, L, condition_dim) or (B, 1, condition_dim).

        Returns:
            modulated_x: Normalized and modulated features.
            gate: Gate values for scaling (None if return_gate=False).
        """
        x_norm = self.norm(x)
        condition_params = self.condition_proj(condition)

        if self.return_gate:
            shift, scale, gate = condition_params.chunk(3, dim=-1)
        else:
            shift, scale = condition_params.chunk(2, dim=-1)
            gate = None

        # Scale around 1 so the zero-initialized projection leaves the
        # normalized input unchanged at the start of training.
        modulated_x = x_norm * (1 + scale) + shift
        return modulated_x, gate
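

# Usage sketch (illustrative): one way AdaLNZero is typically wired into a
# transformer block: modulate, run a sublayer, then gate the residual update.
# The nn.Linear sublayer and all shapes below are assumptions chosen for
# demonstration, not part of the module itself.
if __name__ == "__main__":
    batch, seq_len, dim, cond_dim = 2, 16, 64, 128

    ada_ln = AdaLNZero(dim=dim, condition_dim=cond_dim)
    sublayer = nn.Linear(dim, dim)  # stand-in for an attention or MLP block

    x = torch.randn(batch, seq_len, dim)
    condition = torch.randn(batch, 1, cond_dim)  # broadcasts over sequence length

    modulated, gate = ada_ln(x, condition)
    out = x + gate * sublayer(modulated)  # gated residual connection

    # The zero-initialized projection makes gate == 0, so the block is an
    # identity mapping at initialization.
    assert torch.allclose(out, x)
    print(out.shape)  # torch.Size([2, 16, 64])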