| import torch |
| import torch.nn as nn |
| from einops import rearrange |
|
|
| from model.patch_embed import PatchEmbedding |
| from model.transformer_layer import TransformerLayer |
|
|
| |
| |
|
|
|
|
def get_time_embedding(time_steps, temb_dim):
    """Sinusoidal timestep embedding (Transformer / DDPM style).

    Args:
        time_steps: 1-D tensor of shape (B,) holding diffusion timesteps.
        temb_dim: embedding dimension; must be even.

    Returns:
        Tensor of shape (B, temb_dim) laid out as [sin | cos] halves.

    Raises:
        ValueError: if ``temb_dim`` is odd.
    """
    if temb_dim % 2 != 0:
        raise ValueError(f"temb_dim must be even, got {temb_dim}")
    half_dim = temb_dim // 2

    # factor[i] = 10000^(i / half_dim), i in [0, half_dim).
    # BUGFIX: this must be true division. The original floor division (//)
    # collapsed every exponent to 0 (all i < half_dim), so factor was all
    # ones and every frequency channel was identical.
    factor = 10000 ** (
        torch.arange(
            0, end=half_dim, dtype=torch.float32, device=time_steps.device
        )
        / half_dim
    )

    t_emb = time_steps[:, None].repeat(1, half_dim) / factor
    t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
    return t_emb
|
|
|
|
class DIT(nn.Module):
    """Diffusion Transformer (DiT) for square images.

    Patchifies the input, runs a stack of timestep-conditioned transformer
    layers, applies a final adaptive layer norm (adaLN), and unpatchifies
    back to image space.

    Args:
        config: dict with keys "patch_size", "hidden_dim", "num_layers",
            "temb_dim" (plus whatever TransformerLayer reads).
        im_size: height == width of the (square) training images.
        im_channels: number of input (and output) image channels.
    """

    def __init__(self, config, im_size, im_channels) -> None:
        super().__init__()
        self.image_height = im_size
        self.image_width = im_size
        self.patch_height = config["patch_size"]
        self.patch_width = config["patch_size"]
        self.hidden_dim = config["hidden_dim"]
        self.num_layers = config["num_layers"]
        self.temb_dim = config["temb_dim"]
        # Patch-grid size along each spatial axis at the configured im_size.
        self.nh = self.image_height // self.patch_height
        self.nw = self.image_width // self.patch_width

        self.patch_embd_layer = PatchEmbedding(
            self.image_height,
            self.image_width,
            self.patch_height,
            self.patch_width,
            self.hidden_dim,
            im_channels,
        )

        self.layers = nn.ModuleList(
            [TransformerLayer(config) for _ in range(self.num_layers)]
        )

        # Timestep conditioning MLP: sinusoidal embedding -> hidden_dim.
        self.t_proj = nn.Sequential(
            nn.Linear(self.temb_dim, self.hidden_dim),
            nn.SiLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
        )

        # No learnable affine here: scale/shift come from ada_norm_layer.
        self.norm = nn.LayerNorm(self.hidden_dim, elementwise_affine=False)

        # Produces the (shift, scale) pair for the final adaptive layer norm.
        self.ada_norm_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(self.hidden_dim, 2 * self.hidden_dim, bias=True),
        )

        # Project each token back to one flattened patch.
        # BUGFIX: the output size must be patch_height * patch_width *
        # im_channels so the rearrange in forward() reconstructs im_channels
        # output channels; the previous "2 * patch_height * im_channels"
        # was only shape-consistent when patch_size == 2.
        self.out_proj = nn.Linear(
            self.hidden_dim,
            self.patch_height * self.patch_width * im_channels,
        )

        # DiT-style init: small-normal timestep MLP; zero-init the adaLN
        # modulation and output projection so the block starts as identity.
        nn.init.normal_(self.t_proj[0].weight, std=0.02)
        nn.init.normal_(self.t_proj[2].weight, std=0.02)

        nn.init.constant_(self.ada_norm_layer[-1].weight, 0)
        nn.init.constant_(self.ada_norm_layer[-1].bias, 0)

        nn.init.constant_(self.out_proj.weight, 0)
        nn.init.constant_(self.out_proj.bias, 0)

    def forward(self, x, t):
        """Predict the denoising output for images ``x`` at timesteps ``t``.

        Args:
            x: (B, C, H, W) image batch; H and W must be divisible by the
                patch size.
            t: (B,) int-like diffusion timesteps.

        Returns:
            Tensor of shape (B, C, H, W).
        """
        # (B, C, H, W) -> (B, num_patches, hidden_dim)
        out = self.patch_embd_layer(x)

        # BUGFIX: build the timestep tensor on x's device so CPU timesteps
        # are not mixed with a GPU model/input.
        temb = get_time_embedding(
            torch.as_tensor(t, device=x.device).long(), self.temb_dim
        )
        temb = self.t_proj(temb)

        for layer in self.layers:
            out = layer(out, temb)

        # Final adaLN: shift/scale the normed tokens with timestep-derived
        # parameters (zero-initialized, so this is identity at init).
        pre_mlp_shift, pre_mlp_scale = self.ada_norm_layer(temb).chunk(2, dim=1)
        out = self.norm(out) * (
            1 + pre_mlp_scale.unsqueeze(1)
        ) + pre_mlp_shift.unsqueeze(1)

        # Derive the patch grid from the actual input size so inputs whose
        # H/W differ from the configured im_size still unpatchify correctly.
        actual_h = x.shape[2]
        actual_w = x.shape[3]
        actual_nh = actual_h // self.patch_height
        actual_nw = actual_w // self.patch_width

        # Unpatchify: (B, nh*nw, ph*pw*C) -> (B, C, H, W)
        out = self.out_proj(out)
        out = rearrange(
            out,
            "b (nh nw) (ph pw c) -> b c (nh ph) (nw pw)",
            ph=self.patch_height,
            pw=self.patch_width,
            nw=actual_nw,
            nh=actual_nh,
        )

        return out
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|