"""Sage-T2I: Custom Diffusion Transformer for Text-to-Image generation.""" import torch import torch.nn as nn import torch.nn.functional as F import math def modulate(x, shift, scale): return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) class TimestepEmbedder(nn.Module): def __init__(self, hidden_size, freq_embed_size=256): super().__init__() self.mlp = nn.Sequential( nn.Linear(freq_embed_size, hidden_size, bias=True), nn.SiLU(), nn.Linear(hidden_size, hidden_size, bias=True), ) self.freq_embed_size = freq_embed_size @staticmethod def timestep_embedding(t, dim, max_period=10000): half = dim // 2 freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32, device=t.device) / half) args = t[:, None].float() * freqs[None] return torch.cat([torch.cos(args), torch.sin(args)], dim=-1) def forward(self, t): t_freq = self.timestep_embedding(t, self.freq_embed_size) t_emb = self.mlp(t_freq) return t_emb class CaptionEmbedder(nn.Module): def __init__(self, in_channels, hidden_size, act_layer=nn.SiLU): super().__init__() self.linear = nn.Linear(in_channels, hidden_size, bias=True) self.act = act_layer() def forward(self, x): return self.act(self.linear(x)) class SelfAttention(nn.Module): def __init__(self, hidden_size, num_heads): super().__init__() self.num_heads = num_heads self.head_dim = hidden_size // num_heads self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True) self.proj = nn.Linear(hidden_size, hidden_size, bias=True) def forward(self, x): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim)) attn = attn.softmax(dim=-1) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) return x class CrossAttention(nn.Module): def __init__(self, hidden_size, context_dim, num_heads): super().__init__() self.num_heads = num_heads self.head_dim = hidden_size // num_heads self.q = nn.Linear(hidden_size, hidden_size, bias=True) self.k = nn.Linear(context_dim, hidden_size, bias=True) self.v = nn.Linear(context_dim, hidden_size, bias=True) self.proj = nn.Linear(hidden_size, hidden_size, bias=True) def forward(self, x, context): B, N, C = x.shape _, M, _ = context.shape q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) k = self.k(context).reshape(B, M, self.num_heads, self.head_dim).transpose(1, 2) v = self.v(context).reshape(B, M, self.num_heads, self.head_dim).transpose(1, 2) attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim)) attn = attn.softmax(dim=-1) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) return x class FeedForward(nn.Module): def __init__(self, hidden_size, mlp_ratio=4.0): super().__init__() hidden = int(hidden_size * mlp_ratio) self.fc1 = nn.Linear(hidden_size, hidden, bias=True) self.fc2 = nn.Linear(hidden, hidden_size, bias=True) self.act = nn.GELU(approximate="tanh") def forward(self, x): return self.fc2(self.act(self.fc1(x))) class DiTBlock(nn.Module): def __init__(self, hidden_size, num_heads, context_dim, mlp_ratio=4.0): super().__init__() self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.self_attn = SelfAttention(hidden_size, num_heads) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.cross_attn = CrossAttention(hidden_size, context_dim, num_heads) self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.ff = FeedForward(hidden_size, mlp_ratio) self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) ) def forward(self, x, c, context): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \ self.adaLN_modulation(c).chunk(6, dim=1) x = x + gate_msa.unsqueeze(1) * self.self_attn(modulate(self.norm1(x), shift_msa, scale_msa)) x = x + self.cross_attn(self.norm2(x), context) x = x + gate_mlp.unsqueeze(1) * self.ff(modulate(self.norm3(x), shift_mlp, scale_mlp)) return x class FinalLayer(nn.Module): def __init__(self, hidden_size, patch_size, in_channels): super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = nn.Linear(hidden_size, patch_size * patch_size * in_channels, bias=True) self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) ) def forward(self, x, c): shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) x = modulate(self.norm_final(x), shift, scale) x = self.linear(x) return x class DiT(nn.Module): def __init__(self, config): super().__init__() self.config = config self.in_channels = config.in_channels self.hidden_size = config.hidden_size self.patch_size = config.patch_size self.num_heads = config.num_heads self.x_embedder = nn.Linear(config.patch_size * config.patch_size * config.in_channels, config.hidden_size, bias=True) self.t_embedder = TimestepEmbedder(config.hidden_size) self.c_embedder = CaptionEmbedder(config.context_dim, config.hidden_size) latent_size = config.image_size // 8 num_patches = (latent_size // config.patch_size) ** 2 self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, config.hidden_size), requires_grad=True) self.blocks = nn.ModuleList([ DiTBlock(config.hidden_size, config.num_heads, config.context_dim, config.mlp_ratio) for _ in range(config.num_layers) ]) self.final_layer = FinalLayer(config.hidden_size, config.patch_size, config.in_channels) self.initialize_weights() def initialize_weights(self): nn.init.normal_(self.x_embedder.weight, std=0.02) nn.init.normal_(self.pos_embed, std=0.02) nn.init.normal_(self.c_embedder.linear.weight, std=0.02) for block in self.blocks: nn.init.constant_(block.adaLN_modulation[-1].weight, 0) nn.init.constant_(block.adaLN_modulation[-1].bias, 0) nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) nn.init.constant_(self.final_layer.linear.weight, 0) nn.init.constant_(self.final_layer.linear.bias, 0) def forward(self, x, t, context): B, C, H, W = x.shape x = x.reshape(B, C, H // self.patch_size, self.patch_size, W // self.patch_size, self.patch_size) x = x.permute(0, 2, 4, 1, 3, 5).reshape(B, -1, C * self.patch_size * self.patch_size) x = self.x_embedder(x) + self.pos_embed t_emb = self.t_embedder(t) c_emb = self.c_embedder(context).mean(dim=1) c = t_emb + c_emb for block in self.blocks: x = block(x, c, context) x = self.final_layer(x, c) x = x.reshape(B, H // self.patch_size, W // self.patch_size, self.patch_size, self.patch_size, self.in_channels) x = x.permute(0, 5, 1, 3, 2, 4).reshape(B, self.in_channels, H, W) return x