kaupane commited on
Commit
bd7b3d6
·
verified ·
1 Parent(s): da80cd1

Upload modeling_dit_wikiart.py

Browse files
Files changed (1) hide show
  1. modeling_dit_wikiart.py +178 -0
modeling_dit_wikiart.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
+ from huggingface_hub import PyTorchModelHubMixin
6
+ from typing import Union, Optional, Tuple
7
+
8
+
9
+ class TimestepEmbedder(nn.Module):
10
+ """Module to create timestep's embedding."""
11
+ def __init__(self,hidden_size,frequency_embedding_size=256):
12
+ super().__init__()
13
+ self.mlp = nn.Sequential(
14
+ nn.Linear(frequency_embedding_size,hidden_size),
15
+ nn.SiLU(),
16
+ nn.Linear(hidden_size,hidden_size)
17
+ )
18
+ self.frequency_embedding_size = frequency_embedding_size
19
+
20
+ def forward(self, t):
21
+ half = self.frequency_embedding_size // 2
22
+ freqs = torch.exp(
23
+ -math.log(10000) * torch.arange(start=0,end=half) / half
24
+ ).to(device=t.device)
25
+ args = torch.einsum('i,j->ij', t, freqs.to(t.device))
26
+ freqs = torch.cat([torch.cos(args),torch.sin(args)],dim=-1)
27
+
28
+ mlp_input_dtype = next(self.mlp.parameters()).dtype
29
+ freqs_casted = freqs.to(mlp_input_dtype)
30
+
31
+ return self.mlp(freqs_casted)
32
+
33
+ class ViTAttn(nn.Module):
34
+ def __init__(self,hidden_size,num_heads):
35
+ super().__init__()
36
+ self.attn = nn.MultiheadAttention(hidden_size,num_heads,bias=True,add_bias_kv=True,batch_first=True)
37
+
38
+ def forward(self,x):
39
+ attn, _ = self.attn(x,x,x)
40
+ return attn
41
+
42
+ class DiTBlock(nn.Module):
43
+ """
44
+ DiT Block with adaptive layer norm zero (adaLN-Zero) conditioning.
45
+ Using post-norm
46
+ """
47
+ def __init__(self,hidden_size,num_heads):
48
+ super().__init__()
49
+ self.norm1 = nn.LayerNorm(hidden_size,elementwise_affine=False,eps=1e-6)
50
+ self.attn = ViTAttn(hidden_size,num_heads)
51
+ self.norm2 = nn.LayerNorm(hidden_size,elementwise_affine=False,eps=1e-6)
52
+ self.mlp = nn.Sequential(
53
+ nn.Linear(hidden_size,4*hidden_size),
54
+ nn.GELU(approximate="tanh"),
55
+ nn.Linear(4*hidden_size,hidden_size)
56
+ )
57
+ self.adaLN = nn.Sequential(
58
+ nn.SiLU(),
59
+ nn.Linear(hidden_size,6*hidden_size)
60
+ )
61
+
62
+ def forward(self,x,c):
63
+ gamma_1,beta_1,alpha_1,gamma_2,beta_2,alpha_2 = self.adaLN(c).chunk(6,dim=1)
64
+ x = self.norm1(x + alpha_1.unsqueeze(1) * self.attn(x))
65
+ x = x * (1+gamma_1.unsqueeze(1)) + beta_1.unsqueeze(1)
66
+ x = self.norm2(x + alpha_2.unsqueeze(1) * self.mlp(x))
67
+ x = x * (1+gamma_2.unsqueeze(1)) + beta_2.unsqueeze(1)
68
+ return x
69
+
70
+ class DiTWikiartModel(nn.Module,
71
+ PyTorchModelHubMixin):
72
+ def __init__(self,
73
+ num_blocks=8,
74
+ hidden_size=384,
75
+ num_heads=6,
76
+ patch_size=2,
77
+ num_channels=4,
78
+ img_size=32,
79
+ num_genres=42,
80
+ num_styles=137):
81
+ super().__init__()
82
+ self.hidden_size = hidden_size
83
+ self.patch_size = patch_size
84
+ self.num_channels = num_channels
85
+ self.seq_len = (img_size // patch_size)**2
86
+ self.img_size = img_size
87
+ self.blocks = nn.ModuleList(
88
+ DiTBlock(hidden_size,num_heads) for _ in range(num_blocks)
89
+ )
90
+ self.timestep_embed = TimestepEmbedder(hidden_size)
91
+
92
+ self.num_genres = num_genres
93
+ self.num_styles = num_styles
94
+ self.genre_condition = nn.Embedding(num_genres+1,hidden_size) # +1 for null condition
95
+ self.style_condition = nn.Embedding(num_styles+1,hidden_size)
96
+
97
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.seq_len, hidden_size))
98
+
99
+ patch_dim = num_channels * patch_size * patch_size
100
+ self.proj_in = nn.Linear(patch_dim,hidden_size)
101
+ self.proj_out = nn.Linear(hidden_size,patch_dim)
102
+
103
+ self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
104
+ self.adaLN_final = nn.Sequential(
105
+ nn.SiLU(),
106
+ nn.Linear(hidden_size, 2*hidden_size)
107
+ )
108
+
109
+ self.initialize_weights()
110
+
111
+ def initialize_weights(self):
112
+ nn.init.normal_(self.pos_embed, std=0.02)
113
+ nn.init.normal_(self.proj_out.weight, std=0.02)
114
+ nn.init.zeros_(self.proj_out.bias)
115
+ nn.init.normal_(self.proj_in.weight, std=0.02)
116
+ nn.init.zeros_(self.proj_in.bias)
117
+
118
+ nn.init.normal_(self.timestep_embed.mlp[0].weight, std=0.02)
119
+ nn.init.zeros_(self.timestep_embed.mlp[0].bias)
120
+ nn.init.normal_(self.timestep_embed.mlp[2].weight, std=0.02)
121
+ nn.init.zeros_(self.timestep_embed.mlp[2].bias)
122
+
123
+ for block in self.blocks:
124
+ nn.init.zeros_(block.adaLN[-1].weight)
125
+ nn.init.zeros_(block.adaLN[-1].bias)
126
+
127
+ nn.init.zeros_(self.adaLN_final[-1].weight)
128
+ nn.init.zeros_(self.adaLN_final[-1].bias)
129
+
130
+ nn.init.normal_(self.genre_condition.weight, std=0.02)
131
+ nn.init.normal_(self.style_condition.weight, std=0.02)
132
+
133
+ def patchify(self,z):
134
+ """
135
+ from (batch_size,6,32,32) -> (batch_size,256,24) -> (batch_size,256,hidden_size)
136
+ """
137
+ b,_,_,_ = z.shape
138
+ c = self.num_channels
139
+ p = self.patch_size
140
+ z = z.unfold(2,p,p).unfold(3,p,p) # (b,c,h//p,p,w//p,p)
141
+ z = z.contiguous().view(b,c,-1,p,p) # (b,c,hw//p**2,p,p)
142
+ z = torch.einsum('bcapq->bacpq',z).contiguous().view(b,-1,c*p**2) # (b,hw//p**2,c*p**2)
143
+ return self.proj_in(z) # (b,hw//p**2,hidden_size)
144
+
145
+ def unpatchify(self,z):
146
+ """
147
+ from (batch_size,256,hidden_size) -> (batch_size,256,24) -> (batch_size,6,32,32)
148
+ """
149
+ b,_,_ = z.shape
150
+ c = self.num_channels
151
+ p = self.patch_size
152
+ s = int(self.seq_len ** 0.5)
153
+ i = self.img_size
154
+ z = self.proj_out(z) # (b,hw//p**2,c*p**2)
155
+ z = z.view(b,s,s,c,p,p) # (b,h/p,w/p,c,p,p)
156
+ z = torch.einsum('befcpq->bcepfq',z) # (b,c,h/p,p,w/p,p)
157
+ z = z.contiguous().view(b,c,i,i)
158
+ return z
159
+
160
+ def forward(self,z,t,g,s):
161
+ t_embed = self.timestep_embed(t) # t_embed: (batch_size, hidden_size)
162
+ g_embed = self.genre_condition(g)
163
+ s_embed = self.style_condition(s)
164
+
165
+ c = t_embed + g_embed + s_embed
166
+
167
+ z = self.patchify(z)
168
+ z = z + self.pos_embed
169
+
170
+ for block in self.blocks:
171
+ z = block(z,c)
172
+
173
+ gamma, beta = self.adaLN_final(c).chunk(2,dim=-1)
174
+ z = self.norm_out(z)
175
+ z = z * (1+gamma.unsqueeze(1)) + beta.unsqueeze(1)
176
+
177
+ return self.unpatchify(z)
178
+