import warnings

import torch
import torch.nn as nn

from .modeling_promoe_common import (
    Attention,
    FinalLayer,
    LabelEmbedder,
    Mlp,
    PatchEmbed,
    TimestepEmbedder,
    get_2d_sincos_pos_embed,
    modulate,
)


class DiTBlock(nn.Module):
    """Transformer block conditioned through adaptive LayerNorm (adaLN).

    The conditioning vector ``c`` is projected into six tensors: a shift, a
    scale and a gate for the attention branch, and the same three for the MLP
    branch. Each branch output is gated and added back to the residual.

    Args:
        hidden_size: Token embedding width.
        num_heads: Number of attention heads, forwarded to ``Attention``.
        head_dim: Optional per-head width, forwarded to ``Attention``.
        mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
        use_swiglu: Accepted for interface compatibility but NOT implemented.
            The MLP is always a tanh-approximated GELU ``Mlp``. A warning is
            emitted when this is ``True``.
        **block_kwargs: Extra keyword arguments forwarded to ``Attention``
            (``DiT`` passes ``qk_norm`` this way).
    """

    def __init__(self, hidden_size, num_heads, head_dim=None, mlp_ratio=4.0, use_swiglu=False, **block_kwargs):
        super().__init__()
        if use_swiglu:
            # The flag is plumbed through from DiT but no SwiGLU MLP exists in
            # this block. Make the silent fallback visible instead of hiding it.
            warnings.warn(
                "DiTBlock: use_swiglu=True is not implemented; using the GELU (tanh) Mlp instead.",
                stacklevel=2,
            )
        # Non-affine norms: per-token scale/shift come from adaLN_modulation instead.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        def approx_gelu():
            # Mlp expects an activation *factory*, not an instance.
            return nn.GELU(approximate="tanh")

        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        # Projects c to 6 modulation tensors (shift/scale/gate x {attn, mlp}).
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))

    def forward(self, x, c):
        """Apply the attention and MLP branches with adaLN modulation.

        Args:
            x: Token tensor. Assumed (N, T, hidden_size): the modulation
                vectors are broadcast over dim 1 via ``unsqueeze(1)``.
                TODO confirm against callers.
            c: Conditioning tensor, chunked into 6 along dim 1. Presumably
                (N, hidden_size).

        Returns:
            Tensor with the same shape as ``x``.
        """
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


class DiT(nn.Module):
    """Diffusion Transformer over patchified latents with class conditioning.

    The pipeline is:

    1. Patch-embed the input and add a fixed 2-D sin-cos positional embedding.
    2. Sum the timestep and label embeddings into a conditioning vector.
    3. Run ``depth`` ``DiTBlock``s.
    4. Apply ``FinalLayer`` and fold the patches back into an image.

    When ``learn_sigma`` is set, the output has ``2 * in_channels`` channels;
    otherwise it has ``in_channels``.
    """

    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        qk_norm=False,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
        head_dim=None,
        use_swiglu=False,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads

        # NOTE: construction order is kept stable. It fixes state_dict key
        # order and the RNG draws made by submodule constructors.
        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Fixed (non-trainable) positional embedding, filled in initialize_weights().
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        self.blocks = nn.ModuleList(
            [
                DiTBlock(
                    hidden_size,
                    num_heads,
                    head_dim=head_dim,
                    mlp_ratio=mlp_ratio,
                    qk_norm=qk_norm,
                    use_swiglu=use_swiglu,
                )
                for _ in range(depth)
            ]
        )
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Initialise parameters.

        Linear layers get Xavier-uniform weights and zero biases, and
        ``pos_embed`` is filled with a 2-D sin-cos table. All adaLN projections
        and the final linear projection are zeroed.
        """

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Fixed sin-cos table over a square grid of sqrt(num_patches) per side.
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Treat the patch-embedding conv as a linear layer over flattened patches.
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero the adaLN projections. With zero gates each block's residual
        # branches contribute nothing at initialisation.
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """Fold per-patch predictions back into an image.

        Args:
            x: Tensor of shape (N, T, p * p * C), where ``T`` must be a perfect
                square (``h * w`` with ``h == w``) and ``C == self.out_channels``.

        Returns:
            Tensor of shape (N, C, h * p, w * p).
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        x = x.reshape(x.shape[0], h, w, p, p, c)
        x = torch.einsum("nhwpqc->nchpwq", x)
        return x.reshape(x.shape[0], c, h * p, w * p)

    def forward(self, x, t, context, **kwargs):
        """Predict the model output for noisy input ``x`` at timesteps ``t``.

        Args:
            x: Input latents. A 4-D tensor is used as-is. Any other rank has
                dim 2 squeezed first. This presumably handles (N, C, 1, H, W)
                inputs — TODO confirm.
            t: Timesteps passed to ``TimestepEmbedder``.
            context: Class labels passed to ``LabelEmbedder``. Label dropout is
                applied according to ``self.training``.
            **kwargs: Ignored.

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        y = context
        if len(x.shape) != 4:
            x = x.squeeze(2)
        x = self.x_embedder(x) + self.pos_embed
        t = self.t_embedder(t)
        y = self.y_embedder(y, self.training)
        c = t + y
        for block in self.blocks:
            x = block(x, c)
        x = self.final_layer(x, c)
        return self.unpatchify(x)