import torch
import torch.nn as nn


class ModLN(nn.Module):
    """
    Modulation with adaLN.

    References:
        DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L101
    """

    def __init__(self, inner_dim: int, mod_dim: int, eps: float):
        super().__init__()
        self.norm = nn.LayerNorm(inner_dim, eps=eps)
        # Project the conditioning vector to a per-channel shift and scale.
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(mod_dim, inner_dim * 2),
        )

    @staticmethod
    def modulate(x, shift, scale):
        # x: [N, L, D]; shift, scale: [N, D].
        # unsqueeze(1) broadcasts the modulation over the sequence dimension L.
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
        shift, scale = self.mlp(mod).chunk(2, dim=-1)
        return self.modulate(self.norm(x), shift, scale)
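
# A minimal usage sketch (illustrative, not part of the module above): the
# sizes N=2, L=16, inner_dim=64, and mod_dim=32 are assumptions chosen only
# to demonstrate the expected tensor shapes.
if __name__ == "__main__":
    mod_ln = ModLN(inner_dim=64, mod_dim=32, eps=1e-6)
    x = torch.randn(2, 16, 64)   # token features, shape [N, L, inner_dim]
    mod = torch.randn(2, 32)     # conditioning vector, shape [N, mod_dim]
    out = mod_ln(x, mod)         # modulated features, shape [N, L, inner_dim]
    print(out.shape)             # torch.Size([2, 16, 64])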