BiliSakura's picture
Upload folder using huggingface_hub
24196fc verified
Raw
History Blame Contribute Delete
4.88 kB
import torch
import torch.nn as nn
from .modeling_promoe_common import (
Attention,
FinalLayer,
LabelEmbedder,
Mlp,
PatchEmbed,
TimestepEmbedder,
get_2d_sincos_pos_embed,
modulate,
)
class DiTBlock(nn.Module):
def __init__(self, hidden_size, num_heads, head_dim=None, mlp_ratio=4.0, use_swiglu=False, **block_kwargs):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
def forward(self, x, c):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class DiT(nn.Module):
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
qk_norm=False,
class_dropout_prob=0.1,
num_classes=1000,
learn_sigma=True,
head_dim=None,
use_swiglu=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if learn_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
num_patches = self.x_embedder.num_patches
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
self.blocks = nn.ModuleList(
[
DiTBlock(
hidden_size,
num_heads,
head_dim=head_dim,
mlp_ratio=mlp_ratio,
qk_norm=qk_norm,
use_swiglu=use_swiglu,
)
for _ in range(depth)
]
)
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
self.initialize_weights()
def initialize_weights(self):
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
for block in self.blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def unpatchify(self, x):
c = self.out_channels
p = self.x_embedder.patch_size[0]
h = w = int(x.shape[1] ** 0.5)
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum("nhwpqc->nchpwq", x)
return x.reshape(shape=(x.shape[0], c, h * p, h * p))
def forward(self, x, t, context, **kwargs):
y = context
if len(x.shape) != 4:
x = x.squeeze(2)
x = self.x_embedder(x) + self.pos_embed
t = self.t_embedder(t)
y = self.y_embedder(y, self.training)
c = t + y
for block in self.blocks:
x = block(x, c)
x = self.final_layer(x, c)
return self.unpatchify(x)