| """GlobalContentAdapter - FFN-based adapter for global content conditioning.""" |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange |
|
|
|
|
class GEGLU(nn.Module):
    """Gated GELU linear unit (GEGLU).

    Projects the input to twice the target width with a single linear
    layer, then splits the result into a value half and a gate half and
    returns ``value * gelu(gate)``.
    """

    def __init__(self, dim_in, dim_out):
        """
        Args:
            dim_in: Width of the incoming features.
            dim_out: Width of the produced features.
        """
        super().__init__()
        # One projection produces both halves; cheaper than two Linears.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        """Apply the gated projection along the last dimension."""
        value, gate = torch.chunk(self.proj(x), 2, dim=-1)
        return value * F.gelu(gate)
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise MLP: expand by ``mult``, activate, project to ``dim_out``.

    The first stage is either ``Linear -> GELU`` or, when ``glu=True``, a
    :class:`GEGLU` gated projection. A dropout layer sits between the
    expansion and the output projection.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        """
        Args:
            dim: Input feature width.
            dim_out: Output feature width; defaults to ``dim``.
            mult: Hidden-width multiplier for the expansion stage.
            glu: Use a GEGLU gate instead of plain Linear+GELU.
            dropout: Dropout probability applied after the expansion.
        """
        super().__init__()
        hidden = int(dim * mult)
        if dim_out is None:
            dim_out = dim
        if glu:
            first = GEGLU(dim, hidden)
        else:
            first = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())
        self.net = nn.Sequential(
            first,
            nn.Dropout(dropout),
            nn.Linear(hidden, dim_out),
        )

    def forward(self, x):
        """Run the MLP over the last dimension of ``x``."""
        return self.net(x)
|
|
|
|
class GlobalContentAdapter(nn.Module):
    """Two-stage GEGLU feed-forward stack for global content conditioning.

    Widens a ``(batch, in_dim)`` embedding through two LayerNorm +
    :class:`FeedForward` stages to ``in_dim * channel_mult[-1]`` features,
    then splits the result into ``channel_mult[-1]`` tokens of width
    ``in_dim``.
    """

    def __init__(self, in_dim, channel_mult=None):
        """
        Args:
            in_dim: Width of the incoming embedding (and of each output token).
            channel_mult: Two per-stage expansion factors; falsy values
                (``None``, ``[]``) fall back to ``[2, 4]``.
        """
        super().__init__()
        channel_mult = channel_mult if channel_mult else [2, 4]
        # Per-stage output widths and FeedForward hidden multipliers.
        width1 = in_dim * channel_mult[0]
        width2 = in_dim * channel_mult[1]
        mult1 = 2 * channel_mult[0]
        # NOTE(review): integer division — presumably channel_mult[0] is
        # expected to divide 2 * channel_mult[1]; otherwise the hidden
        # width of the second stage is silently truncated. Confirm intent.
        mult2 = 2 * channel_mult[1] // channel_mult[0]
        self.in_dim = in_dim
        self.channel_mult = channel_mult
        # Keep ff1-before-ff2 construction order so weight init consumes
        # the RNG stream exactly as before.
        self.ff1 = FeedForward(in_dim, dim_out=width1, mult=mult1, glu=True, dropout=0.0)
        self.ff2 = FeedForward(width1, dim_out=width2, mult=mult2, glu=True, dropout=0.0)
        self.norm1 = nn.LayerNorm(in_dim)
        self.norm2 = nn.LayerNorm(width1)

    def forward(self, x):
        """Map ``(b, in_dim)`` to ``(b, channel_mult[-1], in_dim)`` tokens."""
        h = self.ff1(self.norm1(x))
        h = self.ff2(self.norm2(h))
        # Split the widened vector into channel_mult[-1] tokens of in_dim each.
        tokens = rearrange(
            h, "b (n d) -> b n d", n=self.channel_mult[-1], d=self.in_dim
        )
        return tokens.contiguous()
|
|