# Diffusion-Transformer — model/transformer.py
# Source: HuggingFace repo by YashNagraj75, commit d66ddca ("Remove print statements")
import torch
import torch.nn as nn
from einops import rearrange
from model.patch_embed import PatchEmbedding
from model.transformer_layer import TransformerLayer
# from patch_embed import PatchEmbedding
# from transformer_layer import TransformerLayer
def get_time_embedding(time_steps, temb_dim):
    """Build sinusoidal timestep embeddings (transformer-style positional encoding).

    Args:
        time_steps: 1-D tensor of shape (B,) holding diffusion timesteps.
        temb_dim: embedding dimension; must be even (half sin, half cos).

    Returns:
        Float tensor of shape (B, temb_dim): ``[sin(t/f_0..f_{d/2-1}), cos(...)]``.
    """
    if temb_dim % 2 != 0:
        raise ValueError("temb_dim must be even, got {}".format(temb_dim))
    # Geometric frequency spectrum: factor_i = 10000 ** (i / (d/2)).
    # BUGFIX: the original used floor division (//), which made every exponent 0,
    # collapsing all frequencies to 1 — true division restores the spectrum.
    factor = 10000 ** (
        torch.arange(
            0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device
        )
        / (temb_dim // 2)
    )
    t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
    t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
    return t_emb
class DIT(nn.Module):
    """Diffusion Transformer (DiT): patchify -> adaLN transformer layers -> unpatchify.

    Predicts an image-shaped output (e.g. noise) from a noisy image ``x`` and
    timestep(s) ``t``. Assumes square images of side ``im_size`` and square
    patches of side ``config["patch_size"]``.
    """

    def __init__(self, config, im_size, im_channels) -> None:
        super().__init__()
        self.image_height = im_size
        self.image_width = im_size
        self.patch_height = config["patch_size"]
        self.patch_width = config["patch_size"]
        self.hidden_dim = config["hidden_dim"]
        self.num_layers = config["num_layers"]
        self.temb_dim = config["temb_dim"]
        # Number of patches along each spatial axis (for the nominal im_size)
        self.nh = self.image_height // self.patch_height
        self.nw = self.image_width // self.patch_width
        self.patch_embd_layer = PatchEmbedding(
            self.image_height,
            self.image_width,
            self.patch_height,
            self.patch_width,
            self.hidden_dim,
            im_channels,
        )
        self.layers = nn.ModuleList(
            [TransformerLayer(config) for _ in range(self.num_layers)]
        )
        # Project the time step embedding to hidden dim
        self.t_proj = nn.Sequential(
            nn.Linear(self.temb_dim, self.hidden_dim),
            nn.SiLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
        )
        # Norm layer before the unpatchify layer (affine comes from ada_norm_layer)
        self.norm = nn.LayerNorm(self.hidden_dim, elementwise_affine=False)
        # Scale and shift features for the norm layer (adaLN conditioning on t)
        self.ada_norm_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(self.hidden_dim, 2 * self.hidden_dim, bias=True),
        )
        # Final projection back to pixel space: each token becomes one
        # patch_height x patch_width x im_channels patch.
        # BUGFIX: was `2 * self.patch_height * im_channels`, which only equals
        # ph * pw * c when patch_size == 2; any other patch size broke the
        # rearrange in forward(). Behavior is unchanged for patch_size == 2.
        self.out_proj = nn.Linear(
            self.hidden_dim, self.patch_height * self.patch_width * im_channels
        )
        # DiT-style init: small-normal time MLP, zero-init adaLN and output
        # so the model starts as (approximately) the identity mapping.
        nn.init.normal_(self.t_proj[0].weight, std=0.02)
        nn.init.normal_(self.t_proj[2].weight, std=0.02)
        nn.init.constant_(self.ada_norm_layer[-1].weight, 0)
        nn.init.constant_(self.ada_norm_layer[-1].bias, 0)
        nn.init.constant_(self.out_proj.weight, 0)
        nn.init.constant_(self.out_proj.bias, 0)

    def forward(self, x, t):
        """Run the DiT.

        Args:
            x: image batch of shape (B, C, H, W); H and W must be divisible
               by the patch size.
            t: timestep scalar or tensor of shape (B,), castable to long.

        Returns:
            Tensor of shape (B, C, H, W).
        """
        # Patchify: (B, C, H, W) -> (B, num_patches, hidden_dim)
        out = self.patch_embd_layer(x)
        # Get temb and then project it to hidden_dim
        temb = get_time_embedding(torch.as_tensor(t).long(), self.temb_dim)
        temb = self.t_proj(temb)
        for layer in self.layers:
            out = layer(out, temb)
        # adaLN before the output projection: shift/scale conditioned on t
        pre_mlp_shift, pre_mlp_scale = self.ada_norm_layer(temb).chunk(2, dim=1)
        out = self.norm(out) * (
            1 + pre_mlp_scale.unsqueeze(1)
        ) + pre_mlp_shift.unsqueeze(1)
        # Use the actual input resolution (may differ from the nominal im_size)
        actual_h = x.shape[2]
        actual_w = x.shape[3]
        actual_nh = actual_h // self.patch_height
        actual_nw = actual_w // self.patch_width
        # Unpatchify: each token -> ph x pw x c patch, then fold back to image
        out = self.out_proj(out)
        out = rearrange(
            out,
            "b (nh nw) (ph pw c) -> b c (nh ph) (nw pw)",
            ph=self.patch_height,
            pw=self.patch_width,
            nw=actual_nw,
            nh=actual_nh,
        )
        return out
# if __name__ == "__main__":
# config = {
# "patch_size": 2,
# "hidden_dim": 12,
# "num_layers": 1,
# "temb_dim": 128,
# "num_heads": 4,
# "head_dim": 64,
# }
#
# # Test parameters
# im_size = 32 # 32x32 image
# im_channels = 3 # RGB
# batch_size = 2
#
# # Create test data
# x = torch.randn(batch_size, im_channels, im_size, im_size)
# t = torch.randint(0, 1000, (batch_size,))
#
# # Initialize model
# model = DIT(config, im_size, im_channels)
#
# # Forward pass
# with torch.no_grad():
# output = model(x, t)