NoteDance committed
Commit e07f172 · verified · 1 Parent(s): 7cc8696

Upload DiT.py

Files changed (1)
  1. DiT.py +440 -0
DiT.py ADDED
@@ -0,0 +1,440 @@
+ import tensorflow as tf
+ from tensorflow.keras.layers import Conv2D, Dense, Dropout, LayerNormalization, Activation
+ from tensorflow.keras.initializers import RandomNormal
+ from tensorflow.keras import Model
+ import collections.abc
+ from itertools import repeat
+ from typing import Optional
+ import numpy as np
+ import math
+
+
+ def modulate(x, shift, scale):
+     return x * (1 + tf.expand_dims(scale, 1)) + tf.expand_dims(shift, 1)
+
+
+ #################################################################################
+ #               Embedding Layers for Timesteps and Class Labels                 #
+ #################################################################################
+
+ class TimestepEmbedder:
+     """
+     Embeds scalar timesteps into vector representations.
+     """
+     def __init__(self, hidden_size, frequency_embedding_size=256):
+         self.mlp = tf.keras.Sequential()
+         self.mlp.add(Dense(hidden_size, kernel_initializer=RandomNormal(stddev=0.02), use_bias=True))
+         self.mlp.add(Activation('silu'))
+         self.mlp.add(Dense(hidden_size, kernel_initializer=RandomNormal(stddev=0.02), use_bias=True))
+         self.frequency_embedding_size = frequency_embedding_size
+
+     @staticmethod
+     def timestep_embedding(t, dim, max_period=10000):
+         """
+         Create sinusoidal timestep embeddings.
+         :param t: a 1-D Tensor of N indices, one per batch element.
+                   These may be fractional.
+         :param dim: the dimension of the output.
+         :param max_period: controls the minimum frequency of the embeddings.
+         :return: an (N, D) Tensor of positional embeddings.
+         """
+         # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+         half = dim // 2
+         freqs = tf.math.exp(
+             -math.log(max_period) * tf.range(start=0, limit=half, dtype=tf.float32) / half
+         )
+         args = tf.cast(t[:, None], 'float32') * freqs[None]
+         embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], axis=-1)
+         if dim % 2:
+             embedding = tf.concat([embedding, tf.zeros_like(embedding[:, :1])], axis=-1)
+         return embedding
+
+     def __call__(self, t):
+         t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+         t_emb = self.mlp(t_freq)
+         return t_emb
+
+
+ class LabelEmbedder:
+     """
+     Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+     """
+     def __init__(self, num_classes, hidden_size, dropout_prob):
+         use_cfg_embedding = dropout_prob > 0
+         self.embedding_table = tf.Variable(tf.random.normal((num_classes + use_cfg_embedding, hidden_size), stddev=0.02))
+         self.num_classes = num_classes
+         self.dropout_prob = dropout_prob
+
+     def token_drop(self, labels, force_drop_ids=None):
+         """
+         Drops labels to enable classifier-free guidance.
+         """
+         if force_drop_ids is None:
+             drop_ids = tf.random.uniform([labels.shape[0]]) < self.dropout_prob
+         else:
+             drop_ids = force_drop_ids == 1
+         labels = tf.where(drop_ids, self.num_classes, labels)
+         return labels
+
+     def __call__(self, labels, train, force_drop_ids=None):
+         use_dropout = self.dropout_prob > 0
+         if (train and use_dropout) or (force_drop_ids is not None):
+             labels = self.token_drop(labels, force_drop_ids)
+         embeddings = tf.gather(self.embedding_table, labels)
+         return embeddings
+
+
+ #################################################################################
+ #                                 Core DiT Model                                #
+ #################################################################################
+
+ class DiTBlock:
+     """
+     A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
+     """
+     def __init__(self, hidden_size, num_heads, mlp_ratio=4.0):
+         self.norm1 = LayerNormalization(epsilon=1e-6)
+         self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True)
+         self.norm2 = LayerNormalization(epsilon=1e-6)
+         mlp_hidden_dim = int(hidden_size * mlp_ratio)
+         self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
+         self.adaLN_modulation = tf.keras.Sequential()
+         self.adaLN_modulation.add(Activation('silu'))
+         self.adaLN_modulation.add(Dense(6 * hidden_size, kernel_initializer='zeros', use_bias=True))
+
+     def __call__(self, x, c):
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = tf.split(self.adaLN_modulation(c), num_or_size_splits=6, axis=1)
+         x = x + tf.expand_dims(gate_msa, 1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
+         x = x + tf.expand_dims(gate_mlp, 1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+         return x
+
+
+ class FinalLayer:
+     """
+     The final layer of DiT.
+     """
+     def __init__(self, hidden_size, patch_size, out_channels):
+         self.norm_final = LayerNormalization(epsilon=1e-6)
+         self.linear = Dense(patch_size * patch_size * out_channels, kernel_initializer='zeros', use_bias=True)
+         self.adaLN_modulation = tf.keras.Sequential()
+         self.adaLN_modulation.add(Activation('silu'))
+         self.adaLN_modulation.add(Dense(2 * hidden_size, kernel_initializer='zeros', use_bias=True))
+
+     def __call__(self, x, c):
+         shift, scale = tf.split(self.adaLN_modulation(c), num_or_size_splits=2, axis=1)
+         x = modulate(self.norm_final(x), shift, scale)
+         x = self.linear(x)
+         return x
+
+
+ class DiT(Model):
+     """
+     Diffusion model with a Transformer backbone.
+     """
+     def __init__(
+         self,
+         input_size=32,
+         patch_size=2,
+         in_channels=4,
+         hidden_size=1152,
+         depth=28,
+         num_heads=16,
+         mlp_ratio=4.0,
+         class_dropout_prob=0.1,
+         num_classes=1000,
+         learn_sigma=True,
+     ):
+         super(DiT, self).__init__()
+         self.learn_sigma = learn_sigma
+         self.in_channels = in_channels
+         self.out_channels = in_channels * 2 if learn_sigma else in_channels
+         self.patch_size = patch_size
+         self.num_heads = num_heads
+
+         self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
+         self.t_embedder = TimestepEmbedder(hidden_size)
+         self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
+         num_patches = self.x_embedder.num_patches
+         # Will use fixed sin-cos embedding:
+         self.pos_embed = tf.zeros((1, num_patches, hidden_size))
+
+         self.blocks = [
+             DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
+         ]
+         self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
+         self.initialize_weights()
+
+     def initialize_weights(self):
+         # Initialize (and freeze) pos_embed by sin-cos embedding:
+         pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
+         self.pos_embed = tf.Variable(
+             tf.convert_to_tensor(pos_embed, dtype=tf.float32)[tf.newaxis, :], trainable=False)
+
+     def unpatchify(self, x):
+         """
+         x: (N, T, patch_size**2 * C)
+         imgs: (N, H, W, C)
+         """
+         c = self.out_channels
+         p = self.x_embedder.patch_size[0]
+         h = w = int(x.shape[1] ** 0.5)
+         assert h * w == x.shape[1]
+
+         x = tf.reshape(x, (x.shape[0], h, w, p, p, c))
+         # rearrange patches back into an image in channels-last (NHWC) layout
+         x = tf.einsum('nhwpqc->nhpwqc', x)
+         imgs = tf.reshape(x, (x.shape[0], h * p, w * p, c))
+         return imgs
+
+     def __call__(self, x, t, y, training=False):
+         """
+         Forward pass of DiT.
+         x: (N, H, W, C) tensor of spatial inputs (images or latent representations of images)
+         t: (N,) tensor of diffusion timesteps
+         y: (N,) tensor of class labels
+         training: enables label dropout for classifier-free guidance (passed explicitly,
+                   since tf.keras layers keep no persistent training attribute)
+         """
+         x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
+         t = self.t_embedder(t)                   # (N, D)
+         y = self.y_embedder(y, training)         # (N, D)
+         c = t + y                                # (N, D)
+         for block in self.blocks:
+             x = block(x, c)                      # (N, T, D)
+         x = self.final_layer(x, c)               # (N, T, patch_size ** 2 * out_channels)
+         x = self.unpatchify(x)                   # (N, H, W, out_channels)
+         return x
+
+     def forward_with_cfg(self, x, t, y, cfg_scale):
+         """
+         Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
+         """
+         # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
+         half = x[: len(x) // 2]
+         combined = tf.concat([half, half], axis=0)
+         model_out = self(combined, t, y)
+         # For exact reproducibility reasons, we apply classifier-free guidance on only
+         # three channels by default. The standard approach to cfg applies it to all channels.
+         # This can be done by uncommenting the following line and commenting-out the line following that.
+         # eps, rest = model_out[..., :self.in_channels], model_out[..., self.in_channels:]
+         eps, rest = model_out[..., :3], model_out[..., 3:]
+         cond_eps, uncond_eps = tf.split(eps, 2, axis=0)
+         half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+         eps = tf.concat([half_eps, half_eps], axis=0)
+         return tf.concat([eps, rest], axis=-1)
+
+
+ #################################################################################
+ #                   Sine/Cosine Positional Embedding Functions                  #
+ #################################################################################
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
+     """
+     grid_size: int of the grid height and width
+     return:
+     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+     """
+     grid_h = np.arange(grid_size, dtype=np.float32)
+     grid_w = np.arange(grid_size, dtype=np.float32)
+     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+     grid = np.stack(grid, axis=0)
+
+     grid = grid.reshape([2, 1, grid_size, grid_size])
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token and extra_tokens > 0:
+         pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+     return pos_embed
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+     assert embed_dim % 2 == 0
+
+     # use half of dimensions to encode grid_h
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+     return emb
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+     """
+     embed_dim: output dimension for each position
+     pos: a list of positions to be encoded: size (M,)
+     out: (M, D)
+     """
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=np.float64)
+     omega /= embed_dim / 2.
+     omega = 1. / 10000**omega  # (D/2,)
+
+     pos = pos.reshape(-1)  # (M,)
+     out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+     emb_sin = np.sin(out)  # (M, D/2)
+     emb_cos = np.cos(out)  # (M, D/2)
+
+     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+     return emb
+
+
+ #################################################################################
+ #                                   DiT Configs                                 #
+ #################################################################################
+
+ def DiT_XL_2():
+     return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16)
+
+ def DiT_XL_4():
+     return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16)
+
+ def DiT_XL_8():
+     return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16)
+
+ def DiT_L_2():
+     return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16)
+
+ def DiT_L_4():
+     return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16)
+
+ def DiT_L_8():
+     return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16)
+
+ def DiT_B_2():
+     return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12)
+
+ def DiT_B_4():
+     return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12)
+
+ def DiT_B_8():
+     return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12)
+
+ def DiT_S_2():
+     return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6)
+
+ def DiT_S_4():
+     return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6)
+
+ def DiT_S_8():
+     return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6)
+
+
+ DiT_models = {
+     'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
+     'DiT-L/2':  DiT_L_2,  'DiT-L/4':  DiT_L_4,  'DiT-L/8':  DiT_L_8,
+     'DiT-B/2':  DiT_B_2,  'DiT-B/4':  DiT_B_4,  'DiT-B/8':  DiT_B_8,
+     'DiT-S/2':  DiT_S_2,  'DiT-S/4':  DiT_S_4,  'DiT-S/8':  DiT_S_8,
+ }
+
+ def _ntuple(n):
+     def parse(x):
+         if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+             return tuple(x)
+         return tuple(repeat(x, n))
+     return parse
+
+
+ to_2tuple = _ntuple(2)
+
+
+ class PatchEmbed:
+     """ 2D Image to Patch Embedding
+     """
+     def __init__(
+         self,
+         img_size: Optional[int] = 224,
+         patch_size: int = 16,
+         in_chans: int = 3,
+         embed_dim: int = 768,
+         flatten: bool = True,
+         bias: bool = True,
+     ):
+         self.patch_size = to_2tuple(patch_size)
+         if img_size is not None:
+             self.img_size = to_2tuple(img_size)
+             self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
+             self.num_patches = self.grid_size[0] * self.grid_size[1]
+         else:
+             self.img_size = None
+             self.grid_size = None
+             self.num_patches = None
+
+         # flatten spatial dims and transpose to channels last, kept for bwd compat
+         self.flatten = flatten
+
+         self.proj = Conv2D(embed_dim, kernel_size=patch_size, strides=patch_size, use_bias=bias)
+
+     def __call__(self, x):
+         x = self.proj(x)
+         B, H, W, C = x.shape
+         if self.flatten:
+             x = tf.reshape(x, [B, H * W, C])  # NHWC -> NLC
+         return x
+
+
+ class Mlp:
+     """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+     """
+     def __init__(
+         self,
+         in_features,
+         hidden_features=None,
+         out_features=None,
+         act_layer=tf.nn.gelu,
+         norm_layer=None,
+         bias=True,
+         drop=0.,
+         use_conv=False,
+     ):
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         bias = to_2tuple(bias)
+         drop_probs = to_2tuple(drop)
+
+         self.fc1 = Dense(hidden_features, use_bias=bias[0])
+         self.act = act_layer
+         self.drop1 = Dropout(drop_probs[0])
+         self.fc2 = Dense(out_features, use_bias=bias[1])
+         self.drop2 = Dropout(drop_probs[1])
+
+     def __call__(self, x):
+         x = self.fc1(x)
+         x = self.act(x, approximate=True)  # tanh approximation of GELU
+         x = self.drop1(x)
+         x = self.fc2(x)
+         x = self.drop2(x)
+         return x
+
+
+ class Attention:
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 8,
+         qkv_bias: bool = False,
+         attn_drop: float = 0.,
+         proj_drop: float = 0.,
+     ):
+         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+         self.num_heads = num_heads
+         self.head_dim = dim // num_heads
+         self.scale = self.head_dim ** -0.5
+
+         self.qkv = Dense(dim * 3, use_bias=qkv_bias)
+         self.attn_drop = Dropout(attn_drop)
+         self.proj = Dense(dim)
+         self.proj_drop = Dropout(proj_drop)
+
+     def __call__(self, x):
+         B, N, C = x.shape
+         qkv = tf.transpose(tf.reshape(self.qkv(x), (B, N, 3, self.num_heads, self.head_dim)), (2, 0, 3, 1, 4))
+         q, k, v = tf.unstack(qkv)
+
+         q = q * self.scale
+         attn = tf.matmul(q, tf.transpose(k, (0, 1, 3, 2)))
+         attn = tf.nn.softmax(attn)
+         attn = self.attn_drop(attn)
+         x = tf.matmul(attn, v)
+
+         x = tf.reshape(tf.transpose(x, (0, 2, 1, 3)), (B, N, C))
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
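
For reference, a minimal smoke test of the uploaded module might look like the sketch below. It is not part of the commit: it assumes the file is importable as DiT (matching the filename above) and runs eagerly on dummy data; the batch size and the integer ranges for timesteps and labels are illustrative only.

import tensorflow as tf
from DiT import DiT_models   # assumes this commit's DiT.py is on the import path

# Smallest config: DiT-S/2 with the defaults input_size=32, in_channels=4, num_classes=1000.
model = DiT_models['DiT-S/2']()

# Dummy latents in channels-last (NHWC) layout, as expected by PatchEmbed/unpatchify.
x = tf.random.normal((2, 32, 32, 4))                      # (N, H, W, C)
t = tf.random.uniform((2,), maxval=1000, dtype=tf.int32)  # diffusion timesteps
y = tf.random.uniform((2,), maxval=1000, dtype=tf.int32)  # class labels

out = model(x, t, y)   # (N, H, W, out_channels); out_channels = 8 when learn_sigma=True
print(out.shape)       # expected: (2, 32, 32, 8)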