# NOTE: removed non-Python residue from the upload page header
# (uploader banner + "Add files using upload-large-folder tool", commit 66a2b45)
# which made this file syntactically invalid.
"""GlobalTextAdapter - FFN-based adapter for global text conditioning."""
import torch
import torch.nn as nn
import torch.nn.functional as F
class GEGLU(nn.Module):
    """Gated GELU unit (Shazeer, "GLU Variants Improve Transformer").

    Projects the input to twice the target width with a single linear layer,
    splits the result into a value half and a gate half along the last
    dimension, and returns ``value * gelu(gate)``.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # One projection produces both halves; cheaper than two separate Linears.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = torch.chunk(self.proj(x), 2, dim=-1)
        return value * F.gelu(gate)
class FeedForward(nn.Module):
    """Transformer-style position-wise feed-forward block.

    Expands the last dimension to ``dim * mult``, applies a nonlinearity
    (plain GELU, or a GEGLU gate when ``glu=True``), then dropout, then
    projects back down to ``dim_out`` (defaults to ``dim``).
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)
        out_features = dim if dim_out is None else dim_out
        # Choose the expansion stage: gated (GEGLU) or plain Linear+GELU.
        if glu:
            stem = GEGLU(dim, hidden)
        else:
            stem = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())
        self.net = nn.Sequential(
            stem,
            nn.Dropout(dropout),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        return self.net(x)
class GlobalTextAdapter(nn.Module):
    """Two-stage FFN adapter for global text conditioning.

    Applies pre-LayerNorm followed by a gated feed-forward twice: the first
    stage expands the feature dimension from ``in_dim`` to ``2 * in_dim``,
    the second maps it back down to ``in_dim``, so input and output widths
    match.
    """

    def __init__(self, in_dim, max_len=768):
        super().__init__()
        # NOTE(review): max_len is accepted for interface compatibility but is
        # never used inside this module — confirm against callers.
        self.in_dim = in_dim
        hidden_dim = in_dim * 2
        # Stage 1 widens (in_dim -> 2*in_dim); stage 2 narrows back (-> in_dim).
        self.ff1 = FeedForward(in_dim, dim_out=hidden_dim, mult=2, glu=True, dropout=0.0)
        self.ff2 = FeedForward(hidden_dim, dim_out=in_dim, mult=4, glu=True, dropout=0.0)
        self.norm1 = nn.LayerNorm(in_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        h = self.ff1(self.norm1(x))
        return self.ff2(self.norm2(h))