import torch
import torch.nn as nn
from einops import rearrange

from model.patch_embed import PatchEmbedding
from model.transformer_layer import TransformerLayer


def get_time_embedding(time_steps, temb_dim):
    """Sinusoidal timestep embedding (Transformer/DDPM style).

    Args:
        time_steps: 1D integer tensor of shape (B,) holding diffusion steps.
        temb_dim: embedding dimension; must be even.

    Returns:
        Tensor of shape (B, temb_dim): the sin half concatenated with the
        cos half along the last dimension.
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
    # Frequencies 10000^(i / (d/2)) for i in [0, d/2).
    # BUGFIX: true division is required here — the original integer `//`
    # floored every exponent to 0, making all frequencies equal to 1.
    factor = 10000 ** (
        torch.arange(
            0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device
        )
        / (temb_dim // 2)
    )
    t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
    t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
    return t_emb


class DIT(nn.Module):
    """Diffusion Transformer: patchify -> adaLN transformer blocks -> unpatchify.

    Args:
        config: dict with keys "patch_size", "hidden_dim", "num_layers",
            "temb_dim" (plus whatever TransformerLayer reads).
        im_size: square image side length (pixels).
        im_channels: number of image channels.
    """

    def __init__(self, config, im_size, im_channels) -> None:
        super().__init__()
        self.image_height = im_size
        self.image_width = im_size
        self.patch_height = config["patch_size"]
        self.patch_width = config["patch_size"]
        self.hidden_dim = config["hidden_dim"]
        self.num_layers = config["num_layers"]
        self.temb_dim = config["temb_dim"]

        # Number of patches along each spatial dimension.
        self.nh = self.image_height // self.patch_height
        self.nw = self.image_width // self.patch_width

        self.patch_embd_layer = PatchEmbedding(
            self.image_height,
            self.image_width,
            self.patch_height,
            self.patch_width,
            self.hidden_dim,
            im_channels,
        )

        self.layers = nn.ModuleList(
            [TransformerLayer(config) for _ in range(self.num_layers)]
        )

        # Project the time step embedding to hidden dim.
        self.t_proj = nn.Sequential(
            nn.Linear(self.temb_dim, self.hidden_dim),
            nn.SiLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
        )

        # Norm layer before the unpatchify layer; the affine transform is
        # supplied by the timestep-conditioned adaLN layer below instead.
        self.norm = nn.LayerNorm(self.hidden_dim, elementwise_affine=False)

        # Scale and shift features for the final norm layer.
        self.ada_norm_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(self.hidden_dim, 2 * self.hidden_dim, bias=True),
        )

        # Final projection back to per-patch pixels.
        # BUGFIX: the unpatchify rearrange needs ph * pw * C features per
        # token. The original `2 * patch_height * im_channels` only matched
        # that by coincidence when patch_size == 2 (2*2 == 2**2) and broke
        # for every other patch size.
        self.out_proj = nn.Linear(
            self.hidden_dim, self.patch_height * self.patch_width * im_channels
        )

        # DiT-style initialization: small weights for the time projection,
        # zeros for adaLN and the output head so the model starts near
        # identity-conditioned behavior.
        nn.init.normal_(self.t_proj[0].weight, std=0.02)
        nn.init.normal_(self.t_proj[2].weight, std=0.02)
        nn.init.constant_(self.ada_norm_layer[-1].weight, 0)
        nn.init.constant_(self.ada_norm_layer[-1].bias, 0)
        nn.init.constant_(self.out_proj.weight, 0)
        nn.init.constant_(self.out_proj.bias, 0)

    def forward(self, x, t):
        """Run the DiT on a batch of images.

        Args:
            x: image batch of shape (B, C, H, W); H and W must be divisible
                by the patch size.
            t: diffusion timesteps, shape (B,) (any int-convertible tensor).

        Returns:
            Tensor of shape (B, C, H, W) — the model's per-pixel prediction.
        """
        # Patchify: (B, C, H, W) -> (B, nh*nw, hidden_dim).
        out = self.patch_embd_layer(x)

        # Timestep embedding, projected to hidden dim.
        temb = get_time_embedding(torch.as_tensor(t).long(), self.temb_dim)
        temb = self.t_proj(temb)

        for layer in self.layers:
            out = layer(out, temb)

        # adaLN before unpatchify: per-sample scale/shift derived from temb,
        # broadcast over the token dimension.
        pre_mlp_shift, pre_mlp_scale = self.ada_norm_layer(temb).chunk(2, dim=1)
        out = self.norm(out) * (
            1 + pre_mlp_scale.unsqueeze(1)
        ) + pre_mlp_shift.unsqueeze(1)

        # Derive the patch grid from the actual input so forward works for
        # any H, W divisible by the patch size, not just the configured size.
        actual_nh = x.shape[2] // self.patch_height
        actual_nw = x.shape[3] // self.patch_width

        # Unpatchify: (B, nh*nw, ph*pw*C) -> (B, C, H, W).
        out = self.out_proj(out)
        out = rearrange(
            out,
            "b (nh nw) (ph pw c) -> b c (nh ph) (nw pw)",
            ph=self.patch_height,
            pw=self.patch_width,
            nw=actual_nw,
            nh=actual_nh,
        )
        return out