File size: 2,157 Bytes
2cba492
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# Adapted from: https://github.com/facebookresearch/DiT


import torch
from torch import nn


class AdaLNZero(nn.Module):
    """
    Adaptive Layer Normalization Zero (AdaLNZero) module.

    Combines LayerNorm with adaptive conditioning to produce shift, scale, and gate values.
    The gate is used to scale features before residual connection.

    Args:
        dim: Feature dimension
        condition_dim: Conditioning dimension
        eps: LayerNorm epsilon
        return_gate: If True, returns gate value for scaling.
    """

    def __init__(
        self,
        dim: int,
        condition_dim: int,
        eps: float = 1e-5,
        return_gate: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.condition_dim = condition_dim
        self.return_gate = return_gate

        # LayerNorm without learnable parameters
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)

        # Conditioning network: condition -> shift, scale, gate
        output_dim = 3 * dim if return_gate else 2 * dim
        self.condition_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(condition_dim, output_dim),
        )

        # Initialize to zero for stable training
        nn.init.zeros_(self.condition_proj[1].weight)
        nn.init.zeros_(self.condition_proj[1].bias)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor] | None:
        """
        Args:
            x: Input tensor of shape (B, L, dim)
            condition: Conditioning tensor of shape (B, L, condition_dim) or (B, 1, condition_dim)

        Returns:
            modulated_x: Normalized and modulated features
            gate: Gate values for scaling (None if return_gate=False)
        """
        x_norm = self.norm(x)
        condition_params = self.condition_proj(condition)

        if self.return_gate:
            shift, scale, gate = condition_params.chunk(3, dim=-1)
        else:
            shift, scale = condition_params.chunk(2, dim=-1)
            gate = None

        modulated_x = x_norm * (1 + scale) + shift
        return modulated_x, gate