Ngene787 commited on
Commit
0172829
·
verified ·
1 Parent(s): de16d70

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. unet/models/unet.py +447 -0
unet/models/unet.py ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import UNet2DModel, UNet2DConditionModel
2
+
3
+
4
+ class BaseUNet(UNet2DModel):
5
+ """Baseline model given. Don't tweak this.
6
+ This is technically wrong because it's built for 256 x 256 images.
7
+ """
8
+
9
+ def __init__(self, config):
10
+ super().__init__(
11
+ sample_size=config.image_size,
12
+ in_channels=3,
13
+ out_channels=3,
14
+ layers_per_block=2,
15
+ block_out_channels=(128, 128, 256, 256, 512, 512),
16
+ down_block_types=(
17
+ "DownBlock2D", # 256 -> 128
18
+ "DownBlock2D", # 128 -> 64
19
+ "DownBlock2D", # 64 -> 32
20
+ "DownBlock2D", # 32 -> 16
21
+ "AttnDownBlock2D", # 16 -> 8
22
+ "DownBlock2D", # 8 -> 4
23
+ ),
24
+ up_block_types=(
25
+ "UpBlock2D", # 4 -> 8
26
+ "AttnUpBlock2D", # 8 -> 16
27
+ "UpBlock2D", # 16 -> 32
28
+ "UpBlock2D", # 32 -> 64
29
+ "UpBlock2D", # 64 -> 128
30
+ "UpBlock2D", # 128 -> 256
31
+ ),
32
+ )
33
+
34
+
35
+ class DDPMUNet(UNet2DModel):
36
+ """This class mirrors the DDPM paper. I've tweaked it to work with 128 x 128 images.
37
+ We should run some ablations using this class so DO ARGIFY THIS.
38
+ Stuff we should try ablating:
39
+ - layers_per_block: this is the "depth" mentioned in the paper. We can try increasing it to 4.
40
+ - channel width: the paper uses 160, so we can change block_out_channels to (160, 160, 320, 320, 640, 640)
41
+ - fix channels-per-head, vary # heads: this is table 2 in the paper (this class fixes it to 64). We can try 32 and 128.
42
+ - fix # heads, vary channels-per-head: this is also table 2 in the paper. (this requires us to do something like channel_dim // num_heads), with num_heads being [1, 2, 4, 8].
43
+ - remove the attention resolution at 32 and 64: this is the "multi-res attention" ablation in the paper.
44
+ - change the "upsample" and "downsample" attention from "resnet" to "default".
45
+ - using a "wide" unet by changing the channels to [160, 160, 320, 320, 640, 640]."""
46
+
47
+ def __init__(self, config):
48
+ if config.multi_res:
49
+ # this is basically the same structure as the ADMUNet, making this for ablation purposes.
50
+ down_block_types = (
51
+ "DownBlock2D", # 128 -> 64
52
+ "DownBlock2D", # 64 -> 32
53
+ "AttnDownBlock2D", # 32 -> 16
54
+ "AttnDownBlock2D", # 16 -> 8
55
+ "AttnDownBlock2D", # 8 -> 4
56
+ "DownBlock2D", # 4 -> 2
57
+ )
58
+ up_block_types = (
59
+ "UpBlock2D", # 2 -> 4
60
+ "AttnUpBlock2D", # 4 -> 8
61
+ "AttnUpBlock2D", # 8 -> 16
62
+ "AttnUpBlock2D", # 16 -> 32
63
+ "UpBlock2D", # 32 -> 64
64
+ "UpBlock2D", # 64 -> 128
65
+ )
66
+ else:
67
+ down_block_types = (
68
+ "ResnetDownsampleBlock2D", # 128 -> 64
69
+ "ResnetDownsampleBlock2D", # 64 -> 32
70
+ "ResnetDownsampleBlock2D", # 32 -> 16
71
+ "AttnDownBlock2D", # 16 -> 8
72
+ "ResnetDownsampleBlock2D", # 8 -> 4
73
+ "ResnetDownsampleBlock2D", # 4 -> 2
74
+ )
75
+ up_block_types = (
76
+ "ResnetUpsampleBlock2D", # 2 -> 4
77
+ "ResnetUpsampleBlock2D", # 4 -> 8
78
+ "AttnUpBlock2D", # 8 -> 16
79
+ "ResnetUpsampleBlock2D", # 16 -> 32
80
+ "ResnetUpsampleBlock2D", # 32 -> 64
81
+ "ResnetUpsampleBlock2D", # 64 -> 128
82
+ )
83
+ super().__init__(
84
+ sample_size=config.image_size,
85
+ in_channels=3,
86
+ out_channels=3,
87
+ layers_per_block=config.layers_per_block,
88
+ attention_head_dim=config.attention_head_dim,
89
+ # 256 for single head attention at the 16 x 16 resolution.
90
+ time_embedding_type="positional",
91
+ block_out_channels=tuple(
92
+ config.base_channels * m for m in (1, 1, 2, 2, 4, 4)
93
+ ),
94
+ down_block_types=down_block_types,
95
+ up_block_types=up_block_types,
96
+ upsample_type=config.downsample_type,
97
+ downsample_type=config.upsample_type,
98
+ )
99
+
100
+
101
+ class ADMUNet(UNet2DModel):
102
+ """This is the model used in the ADM paper. DO NOT ARGIFY THIS."""
103
+
104
+ def __init__(self, config):
105
+ super().__init__(
106
+ sample_size=config.image_size,
107
+ in_channels=3,
108
+ out_channels=3,
109
+ layers_per_block=2,
110
+ attention_head_dim=64, # this gives varying attention heads for each layer.
111
+ downsample_type="resnet", # This gives BigGAN-style residual samplers.
112
+ upsample_type="resnet", # same as the above.
113
+ resnet_time_scale_shift="scale_shift", # This is the AdaGN portion.
114
+ block_out_channels=(128, 128, 256, 256, 512, 512),
115
+ down_block_types=(
116
+ "DownBlock2D", # 128 -> 64
117
+ "AttnDownBlock2D", # 64 -> 32 (2 attention heads)
118
+ "AttnDownBlock2D", # 32 -> 16 (4 attention heads)
119
+ "AttnDownBlock2D", # 16 -> 8 (8 attention heads)
120
+ "DownBlock2D", # 8 -> 4
121
+ "DownBlock2D", # 4 -> 2
122
+ ),
123
+ up_block_types=(
124
+ "UpBlock2D", # 2 -> 4
125
+ "AttnUpBlock2D", # 4 -> 8 (8 attention heads)
126
+ "AttnUpBlock2D", # 8 -> 16 (4 attention heads)
127
+ "AttnUpBlock2D", # 16 -> 32 (2 attention heads)
128
+ "UpBlock2D", # 32 -> 64
129
+ "UpBlock2D", # 64 -> 128
130
+ ),
131
+ )
132
+
133
+
134
+ class ClassConditionedUNet(UNet2DConditionModel):
135
+ """For simplicity's sake and a quick proof of concept, we can just use the standard DDPM model and add class embeddings to it."""
136
+
137
+ def __init__(self, config):
138
+ super().__init__(
139
+ sample_size=config.image_size,
140
+ in_channels=3,
141
+ out_channels=3,
142
+ layers_per_block=2,
143
+ block_out_channels=(128, 128, 256, 256, 512, 512),
144
+ down_block_types=(
145
+ "DownBlock2D", # 128 -> 64
146
+ "AttnDownBlock2D", # 64 -> 32
147
+ "AttnDownBlock2D", # 32 -> 16
148
+ "AttnDownBlock2D", # 16 -> 8
149
+ "DownBlock2D", # 8 -> 4
150
+ "DownBlock2D", # 4 -> 2
151
+ ),
152
+ up_block_types=(
153
+ "UpBlock2D", # 2 -> 4
154
+ "AttnUpBlock2D", # 4 -> 8
155
+ "AttnUpBlock2D", # 8 -> 16
156
+ "AttnUpBlock2D", # 16 -> 32
157
+ "UpBlock2D", # 32 -> 64
158
+ "UpBlock2D", # 64 -> 128
159
+ ),
160
+ attention_head_dim=64,
161
+ num_class_embeds=2, # 2 classes for male and female.
162
+ class_embed_type=None, # keeping this simple since we just have 0 and 1
163
+ mid_block_type="UNetMidBlock2D", # disable cross attention
164
+ )
165
+
166
+
167
+ ARCHITECTURES = {
168
+ "base": BaseUNet,
169
+ "ddpm": DDPMUNet,
170
+ "adm": ADMUNet,
171
+ "cond": ClassConditionedUNet,
172
+ }
173
+
174
+
175
+ def create_unet(config):
176
+ try:
177
+ cls = ARCHITECTURES[config.unet_variant]
178
+ except KeyError:
179
+ raise ValueError(
180
+ f"Unknown UNet variant {config.unet_variant!r}. "
181
+ f"Choose from {list(ARCHITECTURES)}"
182
+ )
183
+ model = cls(config)
184
+ return model
185
+
186
+
187
+ _COMPRESS_RATE = 4
188
+
189
+
190
+ # TODO: refactor to use Liang's custom implementation.
191
+ class BasicUNet(object):
192
+ def __init__(
193
+ self,
194
+ config,
195
+ compress_rate=1,
196
+ attention_head_dim=8,
197
+ layers_per_block=2,
198
+ block_num=6,
199
+ ):
200
+ self.sample_size = int(config.image_size / compress_rate)
201
+ self.attention_head_dim = attention_head_dim
202
+ self.layers_per_block = layers_per_block
203
+ self.block_num = block_num
204
+
205
+ def unet_b(self):
206
+ model = UNet2DModel(
207
+ sample_size=self.sample_size, # the target image resolution
208
+ in_channels=3, # the number of input channels, 3 for RGB images
209
+ out_channels=3, # the number of output channels
210
+ attention_head_dim=self.attention_head_dim,
211
+ layers_per_block=self.layers_per_block, # how many ResNet layers to use per UNet block
212
+ **self.single_attention_block(),
213
+ )
214
+ return model
215
+
216
+ def unet_l(self):
217
+ model = UNet2DModel(
218
+ sample_size=self.sample_size, # the target image resolution
219
+ in_channels=3, # the number of input channels, 3 for RGB images
220
+ out_channels=3, # the number of output channels
221
+ attention_head_dim=self.attention_head_dim,
222
+ layers_per_block=self.layers_per_block, # how many ResNet layers to use per UNet block
223
+ **self.multi_attention_block(),
224
+ )
225
+ return model
226
+
227
+ def unet_xl(self):
228
+ model = UNet2DModel(
229
+ sample_size=self.sample_size, # the target image resolution
230
+ in_channels=3, # the number of input channels, 3 for RGB images
231
+ out_channels=3, # the number of output channels
232
+ attention_head_dim=self.attention_head_dim,
233
+ layers_per_block=self.layers_per_block, # how many ResNet layers to use per UNet block
234
+ **self.multi_attention_block_xl(),
235
+ )
236
+ return model
237
+
238
+ def single_attention_block(self):
239
+ block_out_channels = [128, 128, 256, 256, 512, 512]
240
+ down_block_types = [
241
+ "DownBlock2D",
242
+ "DownBlock2D",
243
+ "DownBlock2D",
244
+ "DownBlock2D",
245
+ "AttnDownBlock2D",
246
+ "DownBlock2D",
247
+ ]
248
+ up_block_types = [
249
+ "UpBlock2D",
250
+ "AttnUpBlock2D",
251
+ "UpBlock2D",
252
+ "UpBlock2D",
253
+ "UpBlock2D",
254
+ "UpBlock2D",
255
+ ]
256
+ if self.block_num == 6:
257
+ block_out_channels = block_out_channels
258
+ down_block_types = down_block_types
259
+ up_block_types = up_block_types
260
+ elif self.block_num == 8:
261
+ block_out_channels = block_out_channels + [1024] * 2
262
+ down_block_types = ["DownBlock2D"] * 2 + down_block_types
263
+ up_block_types = up_block_types + ["UpBlock2D"] * 2
264
+ blocks = {
265
+ "block_out_channels": tuple(block_out_channels),
266
+ "down_block_types": tuple(down_block_types),
267
+ "up_block_types": tuple(up_block_types),
268
+ }
269
+ return blocks
270
+
271
+ def multi_attention_block(self):
272
+ block_out_channels = [224, 448, 672, 896]
273
+ down_block_types = [
274
+ "DownBlock2D",
275
+ "AttnDownBlock2D",
276
+ "AttnDownBlock2D",
277
+ "AttnDownBlock2D",
278
+ ]
279
+ up_block_types = [
280
+ "AttnUpBlock2D",
281
+ "AttnUpBlock2D",
282
+ "AttnUpBlock2D",
283
+ "UpBlock2D",
284
+ ]
285
+ if self.block_num == 4:
286
+ block_out_channels = block_out_channels
287
+ down_block_types = down_block_types
288
+ up_block_types = up_block_types
289
+ elif self.block_num == 5:
290
+ block_out_channels = block_out_channels + [1120]
291
+ down_block_types = down_block_types + ["AttnDownBlock2D"]
292
+ up_block_types = ["AttnUpBlock2D"] + up_block_types
293
+ elif self.block_num == 6:
294
+ block_out_channels = block_out_channels + [1120, 1344]
295
+ down_block_types = down_block_types + ["AttnDownBlock2D"] * 2
296
+ up_block_types = ["AttnUpBlock2D"] * 2 + up_block_types
297
+ blocks = {
298
+ "block_out_channels": tuple(block_out_channels),
299
+ "down_block_types": tuple(down_block_types),
300
+ "up_block_types": tuple(up_block_types),
301
+ }
302
+ return blocks
303
+
304
+ def multi_attention_block_xl(self):
305
+ block_out_channels = [768, 1024, 1280, 1536]
306
+ down_block_types = [
307
+ "DownBlock2D",
308
+ "AttnDownBlock2D",
309
+ "AttnDownBlock2D",
310
+ "AttnDownBlock2D",
311
+ ]
312
+ up_block_types = [
313
+ "AttnUpBlock2D",
314
+ "AttnUpBlock2D",
315
+ "AttnUpBlock2D",
316
+ "UpBlock2D",
317
+ ]
318
+ if self.block_num == 6:
319
+ block_out_channels = block_out_channels + [1792, 2048]
320
+ down_block_types = down_block_types + ["AttnDownBlock2D"] * 2
321
+ up_block_types = ["AttnUpBlock2D"] * 2 + up_block_types
322
+ blocks = {
323
+ "block_out_channels": tuple(block_out_channels),
324
+ "down_block_types": tuple(down_block_types),
325
+ "up_block_types": tuple(up_block_types),
326
+ }
327
+ return blocks
328
+
329
+
330
+ def unet_b_block_6(config):
331
+ return BasicUNet(config, compress_rate=_COMPRESS_RATE).unet_b()
332
+
333
+
334
+ def unet_b_block_8(config):
335
+ return BasicUNet(config, compress_rate=_COMPRESS_RATE, block_num=8).unet_b()
336
+
337
+
338
+ def unet_b_block_6_head_dim_64(config):
339
+ return BasicUNet(
340
+ config, compress_rate=_COMPRESS_RATE, block_num=6, attention_head_dim=64
341
+ ).unet_b()
342
+
343
+
344
+ def unet_b_block_8_head_dim_64(config):
345
+ return BasicUNet(
346
+ config, compress_rate=_COMPRESS_RATE, block_num=8, attention_head_dim=64
347
+ ).unet_b()
348
+
349
+
350
+ def unet_b_block_8_head_dim_64_layer_4(config):
351
+ return BasicUNet(
352
+ config,
353
+ compress_rate=_COMPRESS_RATE,
354
+ block_num=8,
355
+ attention_head_dim=64,
356
+ layers_per_block=4,
357
+ ).unet_b()
358
+
359
+
360
+ def unet_l_block_4(config):
361
+ return BasicUNet(config, compress_rate=_COMPRESS_RATE, block_num=4).unet_l()
362
+
363
+
364
+ def unet_l_block_4_head_dim_64(config):
365
+ return BasicUNet(
366
+ config, compress_rate=_COMPRESS_RATE, block_num=4, attention_head_dim=64
367
+ ).unet_l()
368
+
369
+
370
+ def unet_l_block_4_head_dim_64_layer_4(config):
371
+ return BasicUNet(
372
+ config,
373
+ compress_rate=_COMPRESS_RATE,
374
+ block_num=4,
375
+ attention_head_dim=64,
376
+ layers_per_block=4,
377
+ ).unet_l()
378
+
379
+
380
+ def unet_l_block_5(config):
381
+ return BasicUNet(config, compress_rate=_COMPRESS_RATE, block_num=5).unet_l()
382
+
383
+
384
+ def unet_l_block_5_head_dim_64(config):
385
+ return BasicUNet(
386
+ config, compress_rate=_COMPRESS_RATE, block_num=5, attention_head_dim=64
387
+ ).unet_l()
388
+
389
+
390
+ def unet_l_block_5_head_dim_64_layer_3(config):
391
+ return BasicUNet(
392
+ config,
393
+ compress_rate=_COMPRESS_RATE,
394
+ block_num=5,
395
+ attention_head_dim=64,
396
+ layers_per_block=3,
397
+ ).unet_l()
398
+
399
+
400
+ def unet_l_block_5_head_dim_64_layer_4(config):
401
+ return BasicUNet(
402
+ config,
403
+ compress_rate=_COMPRESS_RATE,
404
+ block_num=5,
405
+ attention_head_dim=64,
406
+ layers_per_block=4,
407
+ ).unet_l()
408
+
409
+
410
+ def unet_l_block_6(config):
411
+ return BasicUNet(config, compress_rate=_COMPRESS_RATE, block_num=6).unet_l()
412
+
413
+
414
+ def unet_l_block_6_head_dim_64(config):
415
+ return BasicUNet(
416
+ config, compress_rate=_COMPRESS_RATE, block_num=6, attention_head_dim=64
417
+ ).unet_l()
418
+
419
+
420
+ def unet_l_block_6_head_dim_64_layer_4(config):
421
+ return BasicUNet(
422
+ config,
423
+ compress_rate=_COMPRESS_RATE,
424
+ block_num=6,
425
+ attention_head_dim=64,
426
+ layers_per_block=4,
427
+ ).unet_l()
428
+
429
+
430
+ def unet_xl_block_6(config):
431
+ return BasicUNet(config, compress_rate=_COMPRESS_RATE, block_num=6).unet_xl()
432
+
433
+
434
+ def unet_xl_block_6_head_dim_64(config):
435
+ return BasicUNet(
436
+ config, compress_rate=_COMPRESS_RATE, block_num=6, attention_head_dim=64
437
+ ).unet_xl()
438
+
439
+
440
+ def unet_xl_block_6_head_dim_64_layer_4(config):
441
+ return BasicUNet(
442
+ config,
443
+ compress_rate=_COMPRESS_RATE,
444
+ block_num=6,
445
+ attention_head_dim=64,
446
+ layers_per_block=4,
447
+ ).unet_xl()