# Adapted from: https://github.com/facebookresearch/DiT
import torch
from torch import nn
class AdaLNZero(nn.Module):
    """
    Adaptive Layer Normalization Zero (AdaLNZero) module.

    Combines LayerNorm with adaptive conditioning to produce shift, scale, and
    gate values. The gate is used to scale features before the residual
    connection.

    Args:
        dim: Feature dimension.
        condition_dim: Conditioning dimension.
        eps: LayerNorm epsilon.
        return_gate: If True, also returns gate values for scaling.
    """

    def __init__(
        self,
        dim: int,
        condition_dim: int,
        eps: float = 1e-5,
        return_gate: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.condition_dim = condition_dim
        self.return_gate = return_gate
        # LayerNorm without learnable affine parameters; shift and scale come
        # from the conditioning network instead.
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        # Conditioning network: condition -> (shift, scale[, gate])
        output_dim = 3 * dim if return_gate else 2 * dim
        self.condition_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(condition_dim, output_dim),
        )
        # Zero-initialize the projection so modulation starts as the identity
        # and gated residual branches contribute nothing early in training.
        nn.init.zeros_(self.condition_proj[1].weight)
        nn.init.zeros_(self.condition_proj[1].bias)
    def forward(
        self, x: torch.Tensor, condition: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Args:
            x: Input tensor of shape (B, L, dim)
            condition: Conditioning tensor of shape (B, L, condition_dim)
                or (B, 1, condition_dim)

        Returns:
            modulated_x: Normalized and modulated features
            gate: Gate values for scaling (None if return_gate=False)
        """
        x_norm = self.norm(x)
        condition_params = self.condition_proj(condition)
        if self.return_gate:
            shift, scale, gate = condition_params.chunk(3, dim=-1)
        else:
            shift, scale = condition_params.chunk(2, dim=-1)
            gate = None
        modulated_x = x_norm * (1 + scale) + shift
        return modulated_x, gate
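

# Minimal usage sketch (illustrative, not part of the original module): the
# shapes and dimensions below are arbitrary assumptions chosen for a smoke
# test. A caller would typically apply the gate to a residual branch, e.g.
#   x = x + gate * attention(modulated_x)
# which, with the zero-initialized projection above, starts training as an
# identity-like update.
if __name__ == "__main__":
    batch, seq_len, dim, condition_dim = 2, 16, 64, 128
    ada_ln = AdaLNZero(dim=dim, condition_dim=condition_dim)
    x = torch.randn(batch, seq_len, dim)
    # One condition vector per sample, broadcast over the sequence length.
    condition = torch.randn(batch, 1, condition_dim)
    modulated_x, gate = ada_ln(x, condition)
    assert modulated_x.shape == (batch, seq_len, dim)
    assert gate is not None and gate.shape == (batch, 1, dim)
    # At initialization the projection outputs zeros, so modulated_x equals
    # the plain LayerNorm of x and the gate is exactly zero.
    assert torch.allclose(modulated_x, ada_ln.norm(x))
    assert torch.all(gate == 0)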