File size: 4,159 Bytes
0c120cf
 
 
31677e7
 
 
 
 
 
0c120cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58b72ee
 
 
 
 
0c120cf
 
 
 
 
 
 
58b72ee
 
0c120cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import torch
import torch.nn as nn
from einops import rearrange

from model.patch_embed import PatchEmbedding
from model.transformer_layer import TransformerLayer

# from patch_embed import PatchEmbedding
# from transformer_layer import TransformerLayer


def get_time_embedding(time_steps, temb_dim):
    """Build sinusoidal timestep embeddings (transformer/DDPM style).

    Args:
        time_steps: 1-D tensor of shape (B,) holding diffusion timesteps.
        temb_dim: embedding dimension; must be even (split half sin, half cos).

    Returns:
        Tensor of shape (B, temb_dim): [sin(t / f_0..f_{d/2-1}), cos(...)]
        where f_i = 10000^(i / (temb_dim // 2)).
    """
    # BUG FIX: the exponent must use true division. With floor division
    # (`//`), arange(0, temb_dim//2) // (temb_dim//2) is 0 everywhere, so
    # factor == 1 for every position and the embedding carries no
    # frequency spectrum — every channel is just sin(t) / cos(t).
    factor = 10000 ** (
        torch.arange(
            0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device
        )
        / (temb_dim // 2)
    )

    # (B,) -> (B, temb_dim//2), scale each copy of t by its frequency.
    t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
    t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
    return t_emb


class DIT(nn.Module):
    """Diffusion Transformer (DiT).

    Pipeline: patchify the image -> run `num_layers` transformer layers
    conditioned on the timestep embedding -> adaLN-modulated final norm ->
    linear projection back to patch pixels -> unpatchify to an image.

    Args:
        config: dict with keys "patch_size", "hidden_dim", "num_layers",
            "temb_dim" (plus whatever TransformerLayer reads, e.g. heads).
        im_size: input image height/width (square images assumed).
        im_channels: number of input/output image channels.
    """

    def __init__(self, config, im_size, im_channels) -> None:
        super().__init__()
        self.image_height = im_size
        self.image_width = im_size
        self.patch_height = config["patch_size"]
        self.patch_width = config["patch_size"]
        self.hidden_dim = config["hidden_dim"]
        self.num_layers = config["num_layers"]
        self.temb_dim = config["temb_dim"]
        # Number of patches along each axis (assumes im_size % patch_size == 0).
        self.nh = self.image_height // self.patch_height
        self.nw = self.image_width // self.patch_width

        self.patch_embd_layer = PatchEmbedding(
            self.image_height,
            self.image_width,
            self.patch_height,
            self.patch_width,
            self.hidden_dim,
            im_channels,
        )

        self.layers = nn.ModuleList(
            [TransformerLayer(config) for _ in range(self.num_layers)]
        )

        # Project the time step embedding to hidden dim
        self.t_proj = nn.Sequential(
            nn.Linear(self.temb_dim, self.hidden_dim),
            nn.SiLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
        )

        # Norm layer before the unpatchify layer (affine comes from adaLN below)
        self.norm = nn.LayerNorm(self.hidden_dim, elementwise_affine=False)

        # Scale and shift features for the norm layer (adaLN conditioning)
        self.ada_norm_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(self.hidden_dim, 2 * self.hidden_dim, bias=True),
        )

        # Final projection: each token must emit one full patch of pixels,
        # i.e. patch_height * patch_width * im_channels features.
        # BUG FIX: was `2 * self.patch_height * im_channels`, which matches
        # the unpatchify rearrange only when patch_size == 2; any other
        # patch size broke (or silently corrupted) the output reshape.
        self.out_proj = nn.Linear(
            self.hidden_dim, self.patch_height * self.patch_width * im_channels
        )

        # Small init for the time MLP; zero-init adaLN and output so the
        # model starts as (near-)identity, standard for DiT training stability.
        nn.init.normal_(self.t_proj[0].weight, std=0.02)
        nn.init.normal_(self.t_proj[2].weight, std=0.02)

        nn.init.constant_(self.ada_norm_layer[-1].weight, 0)
        nn.init.constant_(self.ada_norm_layer[-1].bias, 0)

        nn.init.constant_(self.out_proj.weight, 0)
        nn.init.constant_(self.out_proj.bias, 0)

    def forward(self, x, t):
        """Predict per-pixel output (e.g. noise) for images `x` at timesteps `t`.

        Args:
            x: (B, C, H, W) image batch; H and W must be divisible by patch_size.
            t: scalar or (B,) timesteps (converted to long internally).

        Returns:
            Tensor of shape (B, C, H, W).
        """
        # Patchify: (B, C, H, W) -> (B, num_patches, hidden_dim)
        out = self.patch_embd_layer(x)

        # Get temb and then project it to hidden_dim
        temb = get_time_embedding(torch.as_tensor(t).long(), self.temb_dim)
        temb = self.t_proj(temb)

        for layer in self.layers:
            out = layer(out, temb)

        # adaLN: per-sample scale/shift applied after the affine-free LayerNorm.
        pre_mlp_shift, pre_mlp_scale = self.ada_norm_layer(temb).chunk(2, dim=1)
        out = self.norm(out) * (
            1 + pre_mlp_scale.unsqueeze(1)
        ) + pre_mlp_shift.unsqueeze(1)

        # Use the actual input size so non-default resolutions unpatchify
        # correctly (as long as PatchEmbedding supports them).
        actual_h = x.shape[2]  # Height from input tensor
        actual_w = x.shape[3]  # Width from input tensor
        actual_nh = actual_h // self.patch_height
        actual_nw = actual_w // self.patch_width

        # Unpatchify: tokens -> image
        out = self.out_proj(out)
        out = rearrange(
            out,
            "b (nh nw) (ph pw c) -> b c (nh ph) (nw pw)",
            ph=self.patch_height,
            pw=self.patch_width,
            nw=actual_nw,
            nh=actual_nh,
        )

        return out


# if __name__ == "__main__":
#    config = {
#        "patch_size": 2,
#        "hidden_dim": 12,
#        "num_layers": 1,
#        "temb_dim": 128,
#        "num_heads": 4,
#        "head_dim": 64,
#    }
#
#    # Test parameters
#    im_size = 32  # 32x32 image
#    im_channels = 3  # RGB
#    batch_size = 2
#
#    # Create test data
#    x = torch.randn(batch_size, im_channels, im_size, im_size)
#    t = torch.randint(0, 1000, (batch_size,))
#
#    # Initialize model
#    model = DIT(config, im_size, im_channels)
#
#    # Forward pass
#    with torch.no_grad():
#        output = model(x, t)