Ngene787 commited on
Commit
69f0b9b
·
verified ·
1 Parent(s): db82999

Delete unet/models

Browse files
Files changed (1) hide show
  1. unet/models/unet.py +0 -447
unet/models/unet.py DELETED
@@ -1,447 +0,0 @@
1
- from diffusers import UNet2DModel, UNet2DConditionModel
2
-
3
-
4
- class BaseUNet(UNet2DModel):
5
- """Baseline model given. Don't tweak this.
6
- This is technically wrong because it's built for 256 x 256 images.
7
- """
8
-
9
- def __init__(self, config):
10
- super().__init__(
11
- sample_size=config.image_size,
12
- in_channels=3,
13
- out_channels=3,
14
- layers_per_block=2,
15
- block_out_channels=(128, 128, 256, 256, 512, 512),
16
- down_block_types=(
17
- "DownBlock2D", # 256 -> 128
18
- "DownBlock2D", # 128 -> 64
19
- "DownBlock2D", # 64 -> 32
20
- "DownBlock2D", # 32 -> 16
21
- "AttnDownBlock2D", # 16 -> 8
22
- "DownBlock2D", # 8 -> 4
23
- ),
24
- up_block_types=(
25
- "UpBlock2D", # 4 -> 8
26
- "AttnUpBlock2D", # 8 -> 16
27
- "UpBlock2D", # 16 -> 32
28
- "UpBlock2D", # 32 -> 64
29
- "UpBlock2D", # 64 -> 128
30
- "UpBlock2D", # 128 -> 256
31
- ),
32
- )
33
-
34
-
35
- class DDPMUNet(UNet2DModel):
36
- """This class mirrors the DDPM paper. I've tweaked it to work with 128 x 128 images.
37
- We should run some ablations using this class so DO ARGIFY THIS.
38
- Stuff we should try ablating:
39
- - layers_per_block: this is the "depth" mentioned in the paper. We can try increasing it to 4.
40
- - channel width: the paper uses 160, so we can change block_out_channels to (160, 160, 320, 320, 640, 640)
41
- - fix channels-per-head, vary # heads: this is table 2 in the paper (this class fixes it to 64). We can try 32 and 128.
42
- - fix # heads, vary channels-per-head: this is also table 2 in the paper. (this requires us to do something like channel_dim // num_heads), with num_heads being [1, 2, 4, 8].
43
- - remove the attention resolution at 32 and 64: this is the "multi-res attention" ablation in the paper.
44
- - change the "upsample" and "downsample" attention from "resnet" to "default".
45
- - using a "wide" unet by changing the channels to [160, 160, 320, 320, 640, 640]."""
46
-
47
- def __init__(self, config):
48
- if config.multi_res:
49
- # this is basically the same structure as the ADMUNet, making this for ablation purposes.
50
- down_block_types = (
51
- "DownBlock2D", # 128 -> 64
52
- "DownBlock2D", # 64 -> 32
53
- "AttnDownBlock2D", # 32 -> 16
54
- "AttnDownBlock2D", # 16 -> 8
55
- "AttnDownBlock2D", # 8 -> 4
56
- "DownBlock2D", # 4 -> 2
57
- )
58
- up_block_types = (
59
- "UpBlock2D", # 2 -> 4
60
- "AttnUpBlock2D", # 4 -> 8
61
- "AttnUpBlock2D", # 8 -> 16
62
- "AttnUpBlock2D", # 16 -> 32
63
- "UpBlock2D", # 32 -> 64
64
- "UpBlock2D", # 64 -> 128
65
- )
66
- else:
67
- down_block_types = (
68
- "ResnetDownsampleBlock2D", # 128 -> 64
69
- "ResnetDownsampleBlock2D", # 64 -> 32
70
- "ResnetDownsampleBlock2D", # 32 -> 16
71
- "AttnDownBlock2D", # 16 -> 8
72
- "ResnetDownsampleBlock2D", # 8 -> 4
73
- "ResnetDownsampleBlock2D", # 4 -> 2
74
- )
75
- up_block_types = (
76
- "ResnetUpsampleBlock2D", # 2 -> 4
77
- "ResnetUpsampleBlock2D", # 4 -> 8
78
- "AttnUpBlock2D", # 8 -> 16
79
- "ResnetUpsampleBlock2D", # 16 -> 32
80
- "ResnetUpsampleBlock2D", # 32 -> 64
81
- "ResnetUpsampleBlock2D", # 64 -> 128
82
- )
83
- super().__init__(
84
- sample_size=config.image_size,
85
- in_channels=3,
86
- out_channels=3,
87
- layers_per_block=config.layers_per_block,
88
- attention_head_dim=config.attention_head_dim,
89
- # 256 for single head attention at the 16 x 16 resolution.
90
- time_embedding_type="positional",
91
- block_out_channels=tuple(
92
- config.base_channels * m for m in (1, 1, 2, 2, 4, 4)
93
- ),
94
- down_block_types=down_block_types,
95
- up_block_types=up_block_types,
96
- upsample_type=config.downsample_type,
97
- downsample_type=config.upsample_type,
98
- )
99
-
100
-
101
- class ADMUNet(UNet2DModel):
102
- """This is the model used in the ADM paper. DO NOT ARGIFY THIS."""
103
-
104
- def __init__(self, config):
105
- super().__init__(
106
- sample_size=config.image_size,
107
- in_channels=3,
108
- out_channels=3,
109
- layers_per_block=2,
110
- attention_head_dim=64, # this gives varying attention heads for each layer.
111
- downsample_type="resnet", # This gives BigGAN-style residual samplers.
112
- upsample_type="resnet", # same as the above.
113
- resnet_time_scale_shift="scale_shift", # This is the AdaGN portion.
114
- block_out_channels=(128, 128, 256, 256, 512, 512),
115
- down_block_types=(
116
- "DownBlock2D", # 128 -> 64
117
- "AttnDownBlock2D", # 64 -> 32 (2 attention heads)
118
- "AttnDownBlock2D", # 32 -> 16 (4 attention heads)
119
- "AttnDownBlock2D", # 16 -> 8 (8 attention heads)
120
- "DownBlock2D", # 8 -> 4
121
- "DownBlock2D", # 4 -> 2
122
- ),
123
- up_block_types=(
124
- "UpBlock2D", # 2 -> 4
125
- "AttnUpBlock2D", # 4 -> 8 (8 attention heads)
126
- "AttnUpBlock2D", # 8 -> 16 (4 attention heads)
127
- "AttnUpBlock2D", # 16 -> 32 (2 attention heads)
128
- "UpBlock2D", # 32 -> 64
129
- "UpBlock2D", # 64 -> 128
130
- ),
131
- )
132
-
133
-
134
- class ClassConditionedUNet(UNet2DConditionModel):
135
- """For simplicity's sake and a quick proof of concept, we can just use the standard DDPM model and add class embeddings to it."""
136
-
137
- def __init__(self, config):
138
- super().__init__(
139
- sample_size=config.image_size,
140
- in_channels=3,
141
- out_channels=3,
142
- layers_per_block=2,
143
- block_out_channels=(128, 128, 256, 256, 512, 512),
144
- down_block_types=(
145
- "DownBlock2D", # 128 -> 64
146
- "AttnDownBlock2D", # 64 -> 32
147
- "AttnDownBlock2D", # 32 -> 16
148
- "AttnDownBlock2D", # 16 -> 8
149
- "DownBlock2D", # 8 -> 4
150
- "DownBlock2D", # 4 -> 2
151
- ),
152
- up_block_types=(
153
- "UpBlock2D", # 2 -> 4
154
- "AttnUpBlock2D", # 4 -> 8
155
- "AttnUpBlock2D", # 8 -> 16
156
- "AttnUpBlock2D", # 16 -> 32
157
- "UpBlock2D", # 32 -> 64
158
- "UpBlock2D", # 64 -> 128
159
- ),
160
- attention_head_dim=64,
161
- num_class_embeds=2, # 2 classes for male and female.
162
- class_embed_type=None, # keeping this simple since we just have 0 and 1
163
- mid_block_type="UNetMidBlock2D", # disable cross attention
164
- )
165
-
166
-
167
- ARCHITECTURES = {
168
- "base": BaseUNet,
169
- "ddpm": DDPMUNet,
170
- "adm": ADMUNet,
171
- "cond": ClassConditionedUNet,
172
- }
173
-
174
-
175
- def create_unet(config):
176
- try:
177
- cls = ARCHITECTURES[config.unet_variant]
178
- except KeyError:
179
- raise ValueError(
180
- f"Unknown UNet variant {config.unet_variant!r}. "
181
- f"Choose from {list(ARCHITECTURES)}"
182
- )
183
- model = cls(config)
184
- return model
185
-
186
-
187
- _COMPRESS_RATE = 4
188
-
189
-
190
- # TODO: refactor to use Liang's custom implementation.
191
- class BasicUNet(object):
192
- def __init__(
193
- self,
194
- config,
195
- compress_rate=1,
196
- attention_head_dim=8,
197
- layers_per_block=2,
198
- block_num=6,
199
- ):
200
- self.sample_size = int(config.image_size / compress_rate)
201
- self.attention_head_dim = attention_head_dim
202
- self.layers_per_block = layers_per_block
203
- self.block_num = block_num
204
-
205
- def unet_b(self):
206
- model = UNet2DModel(
207
- sample_size=self.sample_size, # the target image resolution
208
- in_channels=3, # the number of input channels, 3 for RGB images
209
- out_channels=3, # the number of output channels
210
- attention_head_dim=self.attention_head_dim,
211
- layers_per_block=self.layers_per_block, # how many ResNet layers to use per UNet block
212
- **self.single_attention_block(),
213
- )
214
- return model
215
-
216
- def unet_l(self):
217
- model = UNet2DModel(
218
- sample_size=self.sample_size, # the target image resolution
219
- in_channels=3, # the number of input channels, 3 for RGB images
220
- out_channels=3, # the number of output channels
221
- attention_head_dim=self.attention_head_dim,
222
- layers_per_block=self.layers_per_block, # how many ResNet layers to use per UNet block
223
- **self.multi_attention_block(),
224
- )
225
- return model
226
-
227
- def unet_xl(self):
228
- model = UNet2DModel(
229
- sample_size=self.sample_size, # the target image resolution
230
- in_channels=3, # the number of input channels, 3 for RGB images
231
- out_channels=3, # the number of output channels
232
- attention_head_dim=self.attention_head_dim,
233
- layers_per_block=self.layers_per_block, # how many ResNet layers to use per UNet block
234
- **self.multi_attention_block_xl(),
235
- )
236
- return model
237
-
238
- def single_attention_block(self):
239
- block_out_channels = [128, 128, 256, 256, 512, 512]
240
- down_block_types = [
241
- "DownBlock2D",
242
- "DownBlock2D",
243
- "DownBlock2D",
244
- "DownBlock2D",
245
- "AttnDownBlock2D",
246
- "DownBlock2D",
247
- ]
248
- up_block_types = [
249
- "UpBlock2D",
250
- "AttnUpBlock2D",
251
- "UpBlock2D",
252
- "UpBlock2D",
253
- "UpBlock2D",
254
- "UpBlock2D",
255
- ]
256
- if self.block_num == 6:
257
- block_out_channels = block_out_channels
258
- down_block_types = down_block_types
259
- up_block_types = up_block_types
260
- elif self.block_num == 8:
261
- block_out_channels = block_out_channels + [1024] * 2
262
- down_block_types = ["DownBlock2D"] * 2 + down_block_types
263
- up_block_types = up_block_types + ["UpBlock2D"] * 2
264
- blocks = {
265
- "block_out_channels": tuple(block_out_channels),
266
- "down_block_types": tuple(down_block_types),
267
- "up_block_types": tuple(up_block_types),
268
- }
269
- return blocks
270
-
271
- def multi_attention_block(self):
272
- block_out_channels = [224, 448, 672, 896]
273
- down_block_types = [
274
- "DownBlock2D",
275
- "AttnDownBlock2D",
276
- "AttnDownBlock2D",
277
- "AttnDownBlock2D",
278
- ]
279
- up_block_types = [
280
- "AttnUpBlock2D",
281
- "AttnUpBlock2D",
282
- "AttnUpBlock2D",
283
- "UpBlock2D",
284
- ]
285
- if self.block_num == 4:
286
- block_out_channels = block_out_channels
287
- down_block_types = down_block_types
288
- up_block_types = up_block_types
289
- elif self.block_num == 5:
290
- block_out_channels = block_out_channels + [1120]
291
- down_block_types = down_block_types + ["AttnDownBlock2D"]
292
- up_block_types = ["AttnUpBlock2D"] + up_block_types
293
- elif self.block_num == 6:
294
- block_out_channels = block_out_channels + [1120, 1344]
295
- down_block_types = down_block_types + ["AttnDownBlock2D"] * 2
296
- up_block_types = ["AttnUpBlock2D"] * 2 + up_block_types
297
- blocks = {
298
- "block_out_channels": tuple(block_out_channels),
299
- "down_block_types": tuple(down_block_types),
300
- "up_block_types": tuple(up_block_types),
301
- }
302
- return blocks
303
-
304
- def multi_attention_block_xl(self):
305
- block_out_channels = [768, 1024, 1280, 1536]
306
- down_block_types = [
307
- "DownBlock2D",
308
- "AttnDownBlock2D",
309
- "AttnDownBlock2D",
310
- "AttnDownBlock2D",
311
- ]
312
- up_block_types = [
313
- "AttnUpBlock2D",
314
- "AttnUpBlock2D",
315
- "AttnUpBlock2D",
316
- "UpBlock2D",
317
- ]
318
- if self.block_num == 6:
319
- block_out_channels = block_out_channels + [1792, 2048]
320
- down_block_types = down_block_types + ["AttnDownBlock2D"] * 2
321
- up_block_types = ["AttnUpBlock2D"] * 2 + up_block_types
322
- blocks = {
323
- "block_out_channels": tuple(block_out_channels),
324
- "down_block_types": tuple(down_block_types),
325
- "up_block_types": tuple(up_block_types),
326
- }
327
- return blocks
328
-
329
-
330
- def unet_b_block_6(config):
331
- return BasicUNet(config, compress_rate=_COMPRESS_RATE).unet_b()
332
-
333
-
334
- def unet_b_block_8(config):
335
- return BasicUNet(config, compress_rate=_COMPRESS_RATE, block_num=8).unet_b()
336
-
337
-
338
- def unet_b_block_6_head_dim_64(config):
339
- return BasicUNet(
340
- config, compress_rate=_COMPRESS_RATE, block_num=6, attention_head_dim=64
341
- ).unet_b()
342
-
343
-
344
- def unet_b_block_8_head_dim_64(config):
345
- return BasicUNet(
346
- config, compress_rate=_COMPRESS_RATE, block_num=8, attention_head_dim=64
347
- ).unet_b()
348
-
349
-
350
- def unet_b_block_8_head_dim_64_layer_4(config):
351
- return BasicUNet(
352
- config,
353
- compress_rate=_COMPRESS_RATE,
354
- block_num=8,
355
- attention_head_dim=64,
356
- layers_per_block=4,
357
- ).unet_b()
358
-
359
-
360
- def unet_l_block_4(config):
361
- return BasicUNet(config, compress_rate=_COMPRESS_RATE, block_num=4).unet_l()
362
-
363
-
364
- def unet_l_block_4_head_dim_64(config):
365
- return BasicUNet(
366
- config, compress_rate=_COMPRESS_RATE, block_num=4, attention_head_dim=64
367
- ).unet_l()
368
-
369
-
370
- def unet_l_block_4_head_dim_64_layer_4(config):
371
- return BasicUNet(
372
- config,
373
- compress_rate=_COMPRESS_RATE,
374
- block_num=4,
375
- attention_head_dim=64,
376
- layers_per_block=4,
377
- ).unet_l()
378
-
379
-
380
- def unet_l_block_5(config):
381
- return BasicUNet(config, compress_rate=_COMPRESS_RATE, block_num=5).unet_l()
382
-
383
-
384
- def unet_l_block_5_head_dim_64(config):
385
- return BasicUNet(
386
- config, compress_rate=_COMPRESS_RATE, block_num=5, attention_head_dim=64
387
- ).unet_l()
388
-
389
-
390
- def unet_l_block_5_head_dim_64_layer_3(config):
391
- return BasicUNet(
392
- config,
393
- compress_rate=_COMPRESS_RATE,
394
- block_num=5,
395
- attention_head_dim=64,
396
- layers_per_block=3,
397
- ).unet_l()
398
-
399
-
400
- def unet_l_block_5_head_dim_64_layer_4(config):
401
- return BasicUNet(
402
- config,
403
- compress_rate=_COMPRESS_RATE,
404
- block_num=5,
405
- attention_head_dim=64,
406
- layers_per_block=4,
407
- ).unet_l()
408
-
409
-
410
- def unet_l_block_6(config):
411
- return BasicUNet(config, compress_rate=_COMPRESS_RATE, block_num=6).unet_l()
412
-
413
-
414
- def unet_l_block_6_head_dim_64(config):
415
- return BasicUNet(
416
- config, compress_rate=_COMPRESS_RATE, block_num=6, attention_head_dim=64
417
- ).unet_l()
418
-
419
-
420
- def unet_l_block_6_head_dim_64_layer_4(config):
421
- return BasicUNet(
422
- config,
423
- compress_rate=_COMPRESS_RATE,
424
- block_num=6,
425
- attention_head_dim=64,
426
- layers_per_block=4,
427
- ).unet_l()
428
-
429
-
430
- def unet_xl_block_6(config):
431
- return BasicUNet(config, compress_rate=_COMPRESS_RATE, block_num=6).unet_xl()
432
-
433
-
434
- def unet_xl_block_6_head_dim_64(config):
435
- return BasicUNet(
436
- config, compress_rate=_COMPRESS_RATE, block_num=6, attention_head_dim=64
437
- ).unet_xl()
438
-
439
-
440
- def unet_xl_block_6_head_dim_64_layer_4(config):
441
- return BasicUNet(
442
- config,
443
- compress_rate=_COMPRESS_RATE,
444
- block_num=6,
445
- attention_head_dim=64,
446
- layers_per_block=4,
447
- ).unet_xl()