| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| |
|
class ModLN(nn.Module):
    """
    LayerNorm followed by an adaLN-style shift/scale modulation.

    A conditioning vector ``mod`` is passed through ``SiLU -> Linear`` to
    produce ``2 * inner_dim`` values. They are split in half along the last
    axis into ``shift`` and ``scale``. The normalized input is then
    transformed as ``norm(x) * (1 + scale) + shift``.

    References:
        DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L101
    """

    def __init__(self, inner_dim: int, mod_dim: int, eps: float):
        """
        Args:
            inner_dim: Feature size of ``x``; also the width of each of
                ``shift`` and ``scale``.
            mod_dim: Size of the last axis of the conditioning tensor ``mod``.
            eps: Epsilon forwarded to ``nn.LayerNorm``.
        """
        super().__init__()
        # Registration order (norm, then mlp) and the Sequential layout
        # (index 0 = SiLU, index 1 = Linear) determine state_dict keys such
        # as "mlp.1.weight", so they must stay as-is.
        self.norm = nn.LayerNorm(inner_dim, eps=eps)
        projection_layers = [
            nn.SiLU(),
            nn.Linear(mod_dim, inner_dim * 2),
        ]
        self.mlp = nn.Sequential(*projection_layers)

    @staticmethod
    def modulate(x, shift, scale):
        """
        Apply ``x * (1 + scale) + shift`` after inserting a singleton axis at
        position 1 of ``shift`` and ``scale``.

        The extra axis lets a per-sample ``(N, D)`` shift/scale broadcast over
        the second axis of ``x``. NOTE(review): this assumes ``x`` is
        ``(N, L, D)`` and ``shift``/``scale`` are ``(N, D)``. Other ranks will
        not broadcast as intended; confirm against callers.
        """
        gain = 1 + scale[:, None]
        bias = shift[:, None]
        return x * gain + bias

    def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
        """
        Normalize ``x`` and modulate it with parameters derived from ``mod``.

        Args:
            x: Input whose last axis has size ``inner_dim``.
            mod: Conditioning tensor whose last axis has size ``mod_dim``.

        Returns:
            Tensor with the broadcast shape of ``norm(x)`` and the modulation
            parameters (same shape as ``x`` under the assumed layout).
        """
        projected = self.mlp(mod)
        # First half of the last axis is the shift, second half is the scale.
        shift, scale = projected.chunk(2, dim=-1)
        normalized = self.norm(x)
        return self.modulate(normalized, shift, scale)
| |
|