import torch
import torch.nn as nn
from einops import rearrange
from model.patch_embed import PatchEmbedding
from model.transformer_layer import TransformerLayer
# from patch_embed import PatchEmbedding
# from transformer_layer import TransformerLayer
def get_time_embedding(time_steps, temb_dim):
    """Build sinusoidal timestep embeddings (Transformer/DDPM style).

    Args:
        time_steps: 1-D tensor of shape (B,) holding the diffusion timesteps.
        temb_dim: embedding dimension; must be even (split into sin/cos halves).

    Returns:
        Tensor of shape (B, temb_dim): first half sin, second half cos, each
        evaluated at geometrically spaced frequencies 1 / 10000^(i / half_dim).

    Raises:
        ValueError: if temb_dim is odd.
    """
    if temb_dim % 2 != 0:
        raise ValueError("temb_dim must be even, got %d" % temb_dim)
    half_dim = temb_dim // 2
    # Frequencies 10000^(i / half_dim) for i in [0, half_dim).
    # BUG FIX: the original used integer floor-division `//` here, which
    # made every exponent 0 (i < half_dim), so all frequencies collapsed
    # to 1. True division restores the intended geometric progression.
    factor = 10000 ** (
        torch.arange(
            0, end=half_dim, dtype=torch.float32, device=time_steps.device
        )
        / half_dim
    )
    # (B,) -> (B, half_dim), each column scaled by its inverse frequency.
    t_emb = time_steps[:, None].repeat(1, half_dim) / factor
    t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
    return t_emb
class DIT(nn.Module):
    """Diffusion Transformer: patchify -> conditioned transformer layers ->
    adaLN final norm -> linear projection + unpatchify back to image space.

    Args:
        config: dict with keys "patch_size", "hidden_dim", "num_layers",
            "temb_dim" (plus whatever TransformerLayer reads from it).
        im_size: image height/width used to size the patch embedding
            (assumed square here; forward() re-derives the patch grid from
            the actual input, so other sizes may work if PatchEmbedding
            allows them — TODO confirm).
        im_channels: number of image channels in and out.
    """

    def __init__(self, config, im_size, im_channels) -> None:
        super().__init__()
        self.image_height = im_size
        self.image_width = im_size
        self.patch_height = config["patch_size"]
        self.patch_width = config["patch_size"]
        self.hidden_dim = config["hidden_dim"]
        self.num_layers = config["num_layers"]
        self.temb_dim = config["temb_dim"]
        # Patch-grid size for the init-time image dimensions.
        self.nh = self.image_height // self.patch_height
        self.nw = self.image_width // self.patch_width
        self.patch_embd_layer = PatchEmbedding(
            self.image_height,
            self.image_width,
            self.patch_height,
            self.patch_width,
            self.hidden_dim,
            im_channels,
        )
        self.layers = nn.ModuleList(
            [TransformerLayer(config) for _ in range(self.num_layers)]
        )
        # Project the time step embedding to hidden dim
        self.t_proj = nn.Sequential(
            nn.Linear(self.temb_dim, self.hidden_dim),
            nn.SiLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
        )
        # Norm layer before the unpatchify layer; affine params are supplied
        # per-timestep by ada_norm_layer, hence elementwise_affine=False.
        self.norm = nn.LayerNorm(self.hidden_dim, elementwise_affine=False)
        # Scale and shift features for the norm layer (adaLN-style).
        self.ada_norm_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(self.hidden_dim, 2 * self.hidden_dim, bias=True),
        )
        # Final projection: each token becomes one flattened
        # (patch_height x patch_width x im_channels) patch.
        # BUG FIX: was `2 * self.patch_height * im_channels`, which equals
        # the (ph pw c) size the forward() rearrange unpacks only when
        # patch_width == 2; the general form is ph * pw * c.
        self.out_proj = nn.Linear(
            self.hidden_dim,
            self.patch_height * self.patch_width * im_channels,
        )
        nn.init.normal_(self.t_proj[0].weight, std=0.02)
        nn.init.normal_(self.t_proj[2].weight, std=0.02)
        # Zero-init conditioning and output projections so the final norm
        # starts as plain LayerNorm and the model output starts at zero.
        nn.init.constant_(self.ada_norm_layer[-1].weight, 0)
        nn.init.constant_(self.ada_norm_layer[-1].bias, 0)
        nn.init.constant_(self.out_proj.weight, 0)
        nn.init.constant_(self.out_proj.bias, 0)

    def forward(self, x, t):
        """Predict the model output for noisy images `x` at timesteps `t`.

        Args:
            x: image batch of shape (B, C, H, W).
            t: timesteps, scalar/array-like convertible via torch.as_tensor
               to shape (B,).

        Returns:
            Tensor of shape (B, C, H, W) (same spatial size as `x`).
        """
        # Patchify: (B, C, H, W) -> (B, num_patches, hidden_dim)
        out = self.patch_embd_layer(x)
        # Get temb and then project it to hidden_dim
        temb = get_time_embedding(torch.as_tensor(t).long(), self.temb_dim)
        temb = self.t_proj(temb)
        for layer in self.layers:
            out = layer(out, temb)
        # Per-timestep (shift, scale) for the final norm.
        pre_mlp_shift, pre_mlp_scale = self.ada_norm_layer(temb).chunk(2, dim=1)
        # unsqueeze(1) broadcasts the conditioning over the patch dimension.
        out = self.norm(out) * (
            1 + pre_mlp_scale.unsqueeze(1)
        ) + pre_mlp_shift.unsqueeze(1)
        # Re-derive the patch grid from the actual input so forward() is not
        # tied to the init-time image size.
        actual_h = x.shape[2]  # Height from input tensor
        actual_w = x.shape[3]  # Width from input tensor
        actual_nh = actual_h // self.patch_height
        actual_nw = actual_w // self.patch_width
        # Unpatchify: tokens -> pixel patches -> full image
        out = self.out_proj(out)
        out = rearrange(
            out,
            "b (nh nw) (ph pw c) -> b c (nh ph) (nw pw)",
            ph=self.patch_height,
            pw=self.patch_width,
            nw=actual_nw,
            nh=actual_nh,
        )
        return out
# if __name__ == "__main__":
# config = {
# "patch_size": 2,
# "hidden_dim": 12,
# "num_layers": 1,
# "temb_dim": 128,
# "num_heads": 4,
# "head_dim": 64,
# }
#
# # Test parameters
# im_size = 32 # 32x32 image
# im_channels = 3 # RGB
# batch_size = 2
#
# # Create test data
# x = torch.randn(batch_size, im_channels, im_size, im_size)
# t = torch.randint(0, 1000, (batch_size,))
#
# # Initialize model
# model = DIT(config, im_size, im_channels)
#
# # Forward pass
# with torch.no_grad():
# output = model(x, t)