BiliSakura commited on
Commit
9e64921
·
verified ·
1 Parent(s): 044c88b

Delete ADM-G-512/classifier/modeling_adm.py

Browse files
Files changed (1) hide show
  1. ADM-G-512/classifier/modeling_adm.py +0 -772
ADM-G-512/classifier/modeling_adm.py DELETED
@@ -1,772 +0,0 @@
1
- import math
2
- from abc import abstractmethod
3
- from typing import Optional
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from torch.utils.checkpoint import checkpoint as torch_checkpoint
9
-
10
-
11
- NUM_CLASSES = 1000
12
-
13
-
14
- def conv_nd(dims: int, *args, **kwargs):
15
- if dims == 1:
16
- return nn.Conv1d(*args, **kwargs)
17
- if dims == 2:
18
- return nn.Conv2d(*args, **kwargs)
19
- if dims == 3:
20
- return nn.Conv3d(*args, **kwargs)
21
- raise ValueError(f"unsupported dimensions: {dims}")
22
-
23
-
24
- def linear(*args, **kwargs):
25
- return nn.Linear(*args, **kwargs)
26
-
27
-
28
- def avg_pool_nd(dims: int, *args, **kwargs):
29
- if dims == 1:
30
- return nn.AvgPool1d(*args, **kwargs)
31
- if dims == 2:
32
- return nn.AvgPool2d(*args, **kwargs)
33
- if dims == 3:
34
- return nn.AvgPool3d(*args, **kwargs)
35
- raise ValueError(f"unsupported dimensions: {dims}")
36
-
37
-
38
- class GroupNorm32(nn.GroupNorm):
39
- def forward(self, x):
40
- return super().forward(x.float()).type(x.dtype)
41
-
42
-
43
- def normalization(channels: int):
44
- return GroupNorm32(32, channels)
45
-
46
-
47
- def zero_module(module: nn.Module):
48
- for p in module.parameters():
49
- p.detach().zero_()
50
- return module
51
-
52
-
53
- def timestep_embedding(timesteps: torch.Tensor, dim: int, max_period: int = 10000):
54
- half = dim // 2
55
- freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
56
- device=timesteps.device
57
- )
58
- args = timesteps[:, None].float() * freqs[None]
59
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
60
- if dim % 2:
61
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
62
- return embedding
63
-
64
-
65
- def convert_module_to_f16(module: nn.Module):
66
- if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
67
- module.weight.data = module.weight.data.half()
68
- if module.bias is not None:
69
- module.bias.data = module.bias.data.half()
70
-
71
-
72
- def convert_module_to_f32(module: nn.Module):
73
- if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
74
- module.weight.data = module.weight.data.float()
75
- if module.bias is not None:
76
- module.bias.data = module.bias.data.float()
77
-
78
-
79
- class TimestepBlock(nn.Module):
80
- @abstractmethod
81
- def forward(self, x, emb):
82
- raise NotImplementedError
83
-
84
-
85
- class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
86
- def forward(self, x, emb):
87
- for layer in self:
88
- if isinstance(layer, TimestepBlock):
89
- x = layer(x, emb)
90
- else:
91
- x = layer(x)
92
- return x
93
-
94
-
95
- class Upsample(nn.Module):
96
- def __init__(self, channels, use_conv, dims=2, out_channels=None):
97
- super().__init__()
98
- self.channels = channels
99
- self.out_channels = out_channels or channels
100
- self.use_conv = use_conv
101
- self.dims = dims
102
- if use_conv:
103
- self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
104
-
105
- def forward(self, x):
106
- assert x.shape[1] == self.channels
107
- if self.dims == 3:
108
- x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
109
- else:
110
- x = F.interpolate(x, scale_factor=2, mode="nearest")
111
- if self.use_conv:
112
- x = self.conv(x)
113
- return x
114
-
115
-
116
- class Downsample(nn.Module):
117
- def __init__(self, channels, use_conv, dims=2, out_channels=None):
118
- super().__init__()
119
- self.channels = channels
120
- self.out_channels = out_channels or channels
121
- self.use_conv = use_conv
122
- stride = 2 if dims != 3 else (1, 2, 2)
123
- if use_conv:
124
- self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1)
125
- else:
126
- assert self.channels == self.out_channels
127
- self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
128
-
129
- def forward(self, x):
130
- assert x.shape[1] == self.channels
131
- return self.op(x)
132
-
133
-
134
- class ResBlock(TimestepBlock):
135
- def __init__(
136
- self,
137
- channels,
138
- emb_channels,
139
- dropout,
140
- out_channels=None,
141
- use_conv=False,
142
- use_scale_shift_norm=False,
143
- dims=2,
144
- use_checkpoint=False,
145
- up=False,
146
- down=False,
147
- ):
148
- super().__init__()
149
- self.channels = channels
150
- self.out_channels = out_channels or channels
151
- self.use_checkpoint = use_checkpoint
152
- self.use_scale_shift_norm = use_scale_shift_norm
153
- self.in_layers = nn.Sequential(
154
- normalization(channels),
155
- nn.SiLU(),
156
- conv_nd(dims, channels, self.out_channels, 3, padding=1),
157
- )
158
-
159
- self.updown = up or down
160
- if up:
161
- self.h_upd = Upsample(channels, False, dims)
162
- self.x_upd = Upsample(channels, False, dims)
163
- elif down:
164
- self.h_upd = Downsample(channels, False, dims)
165
- self.x_upd = Downsample(channels, False, dims)
166
- else:
167
- self.h_upd = self.x_upd = nn.Identity()
168
-
169
- self.emb_layers = nn.Sequential(
170
- nn.SiLU(),
171
- linear(emb_channels, 2 * self.out_channels if use_scale_shift_norm else self.out_channels),
172
- )
173
- self.out_layers = nn.Sequential(
174
- normalization(self.out_channels),
175
- nn.SiLU(),
176
- nn.Dropout(p=dropout),
177
- zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
178
- )
179
-
180
- if self.out_channels == channels:
181
- self.skip_connection = nn.Identity()
182
- elif use_conv:
183
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
184
- else:
185
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
186
-
187
- def forward(self, x, emb):
188
- if self.use_checkpoint and x.requires_grad:
189
- return torch_checkpoint(self._forward, x, emb, use_reentrant=False)
190
- return self._forward(x, emb)
191
-
192
- def _forward(self, x, emb):
193
- if self.updown:
194
- in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
195
- h = in_rest(x)
196
- h = self.h_upd(h)
197
- x = self.x_upd(x)
198
- h = in_conv(h)
199
- else:
200
- h = self.in_layers(x)
201
-
202
- emb_out = self.emb_layers(emb).type(h.dtype)
203
- while len(emb_out.shape) < len(h.shape):
204
- emb_out = emb_out[..., None]
205
- if self.use_scale_shift_norm:
206
- out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
207
- scale, shift = torch.chunk(emb_out, 2, dim=1)
208
- h = out_norm(h) * (1 + scale) + shift
209
- h = out_rest(h)
210
- else:
211
- h = h + emb_out
212
- h = self.out_layers(h)
213
- return self.skip_connection(x) + h
214
-
215
-
216
- class QKVAttentionLegacy(nn.Module):
217
- def __init__(self, n_heads):
218
- super().__init__()
219
- self.n_heads = n_heads
220
-
221
- def forward(self, qkv):
222
- bs, width, length = qkv.shape
223
- assert width % (3 * self.n_heads) == 0
224
- ch = width // (3 * self.n_heads)
225
- q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
226
- scale = 1 / math.sqrt(math.sqrt(ch))
227
- weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)
228
- weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
229
- a = torch.einsum("bts,bcs->bct", weight, v)
230
- return a.reshape(bs, -1, length)
231
-
232
-
233
- class QKVAttention(nn.Module):
234
- def __init__(self, n_heads):
235
- super().__init__()
236
- self.n_heads = n_heads
237
-
238
- def forward(self, qkv):
239
- bs, width, length = qkv.shape
240
- assert width % (3 * self.n_heads) == 0
241
- ch = width // (3 * self.n_heads)
242
- q, k, v = qkv.chunk(3, dim=1)
243
- scale = 1 / math.sqrt(math.sqrt(ch))
244
- weight = torch.einsum(
245
- "bct,bcs->bts",
246
- (q * scale).view(bs * self.n_heads, ch, length),
247
- (k * scale).view(bs * self.n_heads, ch, length),
248
- )
249
- weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
250
- a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
251
- return a.reshape(bs, -1, length)
252
-
253
-
254
- class AttentionBlock(nn.Module):
255
- def __init__(
256
- self,
257
- channels,
258
- num_heads=1,
259
- num_head_channels=-1,
260
- use_checkpoint=False,
261
- use_new_attention_order=False,
262
- ):
263
- super().__init__()
264
- if num_head_channels == -1:
265
- self.num_heads = num_heads
266
- else:
267
- assert channels % num_head_channels == 0
268
- self.num_heads = channels // num_head_channels
269
- self.use_checkpoint = use_checkpoint
270
- self.norm = normalization(channels)
271
- self.qkv = conv_nd(1, channels, channels * 3, 1)
272
- self.attention = QKVAttention(self.num_heads) if use_new_attention_order else QKVAttentionLegacy(self.num_heads)
273
- self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
274
-
275
- def forward(self, x):
276
- if self.use_checkpoint and x.requires_grad:
277
- return torch_checkpoint(self._forward, x, use_reentrant=False)
278
- return self._forward(x)
279
-
280
- def _forward(self, x):
281
- b, c, *spatial = x.shape
282
- x = x.reshape(b, c, -1)
283
- qkv = self.qkv(self.norm(x))
284
- h = self.attention(qkv)
285
- h = self.proj_out(h)
286
- return (x + h).reshape(b, c, *spatial)
287
-
288
-
289
- class AttentionPool2d(nn.Module):
290
- """CLIP-style attention pooling used by ADM noisy classifiers."""
291
-
292
- def __init__(self, spacial_dim: int, embed_dim: int, num_heads_channels: int, output_dim: int = None):
293
- super().__init__()
294
- self.positional_embedding = nn.Parameter(torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
295
- self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
296
- self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
297
- self.num_heads = embed_dim // num_heads_channels
298
- self.attention = QKVAttention(self.num_heads)
299
-
300
- def forward(self, x):
301
- b, c, *_spatial = x.shape
302
- x = x.reshape(b, c, -1)
303
- x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)
304
- x = x + self.positional_embedding[None, :, :].to(x.dtype)
305
- x = self.qkv_proj(x)
306
- x = self.attention(x)
307
- x = self.c_proj(x)
308
- return x[:, :, 0]
309
-
310
-
311
- class EncoderUNetModel(nn.Module):
312
- """Noisy image classifier backbone for ADM-G (classifier guidance)."""
313
-
314
- def __init__(
315
- self,
316
- image_size,
317
- in_channels,
318
- model_channels,
319
- out_channels,
320
- num_res_blocks,
321
- attention_resolutions,
322
- dropout=0,
323
- channel_mult=(1, 2, 4, 8),
324
- conv_resample=True,
325
- dims=2,
326
- use_checkpoint=False,
327
- use_fp16=False,
328
- num_heads=1,
329
- num_head_channels=-1,
330
- use_scale_shift_norm=False,
331
- resblock_updown=False,
332
- use_new_attention_order=False,
333
- pool="adaptive",
334
- ):
335
- super().__init__()
336
-
337
- self.in_channels = in_channels
338
- self.model_channels = model_channels
339
- self.out_channels = out_channels
340
- self.num_res_blocks = num_res_blocks
341
- self.dropout = dropout
342
- self.channel_mult = channel_mult
343
- self.conv_resample = conv_resample
344
- self.use_checkpoint = use_checkpoint
345
- self.dtype = torch.float16 if use_fp16 else torch.float32
346
- self.num_heads = num_heads
347
- self.num_head_channels = num_head_channels
348
-
349
- time_embed_dim = model_channels * 4
350
- self.time_embed = nn.Sequential(
351
- linear(model_channels, time_embed_dim),
352
- nn.SiLU(),
353
- linear(time_embed_dim, time_embed_dim),
354
- )
355
-
356
- ch = int(channel_mult[0] * model_channels)
357
- self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))])
358
- self._feature_size = ch
359
- input_block_chans = [ch]
360
- ds = 1
361
- for level, mult in enumerate(channel_mult):
362
- for _ in range(num_res_blocks):
363
- layers = [
364
- ResBlock(
365
- ch,
366
- time_embed_dim,
367
- dropout,
368
- out_channels=int(mult * model_channels),
369
- dims=dims,
370
- use_checkpoint=use_checkpoint,
371
- use_scale_shift_norm=use_scale_shift_norm,
372
- )
373
- ]
374
- ch = int(mult * model_channels)
375
- if ds in attention_resolutions:
376
- layers.append(
377
- AttentionBlock(
378
- ch,
379
- use_checkpoint=use_checkpoint,
380
- num_heads=num_heads,
381
- num_head_channels=num_head_channels,
382
- use_new_attention_order=use_new_attention_order,
383
- )
384
- )
385
- self.input_blocks.append(TimestepEmbedSequential(*layers))
386
- self._feature_size += ch
387
- input_block_chans.append(ch)
388
- if level != len(channel_mult) - 1:
389
- out_ch = ch
390
- self.input_blocks.append(
391
- TimestepEmbedSequential(
392
- ResBlock(
393
- ch,
394
- time_embed_dim,
395
- dropout,
396
- out_channels=out_ch,
397
- dims=dims,
398
- use_checkpoint=use_checkpoint,
399
- use_scale_shift_norm=use_scale_shift_norm,
400
- down=True,
401
- )
402
- if resblock_updown
403
- else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
404
- )
405
- )
406
- ch = out_ch
407
- input_block_chans.append(ch)
408
- ds *= 2
409
- self._feature_size += ch
410
-
411
- self.middle_block = TimestepEmbedSequential(
412
- ResBlock(
413
- ch,
414
- time_embed_dim,
415
- dropout,
416
- dims=dims,
417
- use_checkpoint=use_checkpoint,
418
- use_scale_shift_norm=use_scale_shift_norm,
419
- ),
420
- AttentionBlock(
421
- ch,
422
- use_checkpoint=use_checkpoint,
423
- num_heads=num_heads,
424
- num_head_channels=num_head_channels,
425
- use_new_attention_order=use_new_attention_order,
426
- ),
427
- ResBlock(
428
- ch,
429
- time_embed_dim,
430
- dropout,
431
- dims=dims,
432
- use_checkpoint=use_checkpoint,
433
- use_scale_shift_norm=use_scale_shift_norm,
434
- ),
435
- )
436
- self._feature_size += ch
437
- self.pool = pool
438
- if pool == "adaptive":
439
- self.out = nn.Sequential(
440
- normalization(ch),
441
- nn.SiLU(),
442
- nn.AdaptiveAvgPool2d((1, 1)),
443
- zero_module(conv_nd(dims, ch, out_channels, 1)),
444
- nn.Flatten(),
445
- )
446
- elif pool == "attention":
447
- assert num_head_channels != -1
448
- self.out = nn.Sequential(
449
- normalization(ch),
450
- nn.SiLU(),
451
- AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels),
452
- )
453
- elif pool == "spatial":
454
- self.out = nn.Sequential(
455
- nn.Linear(self._feature_size, 2048),
456
- nn.ReLU(),
457
- nn.Linear(2048, out_channels),
458
- )
459
- elif pool == "spatial_v2":
460
- self.out = nn.Sequential(
461
- nn.Linear(self._feature_size, 2048),
462
- normalization(2048),
463
- nn.SiLU(),
464
- nn.Linear(2048, out_channels),
465
- )
466
- else:
467
- raise NotImplementedError(f"Unexpected {pool} pooling")
468
-
469
- def convert_to_fp16(self):
470
- self.input_blocks.apply(convert_module_to_f16)
471
- self.middle_block.apply(convert_module_to_f16)
472
-
473
- def convert_to_fp32(self):
474
- self.input_blocks.apply(convert_module_to_f32)
475
- self.middle_block.apply(convert_module_to_f32)
476
-
477
- def forward(self, x, timesteps):
478
- emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
479
- results = []
480
- h = x.type(self.dtype)
481
- for module in self.input_blocks:
482
- h = module(h, emb)
483
- if self.pool.startswith("spatial"):
484
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
485
- h = self.middle_block(h, emb)
486
- if self.pool.startswith("spatial"):
487
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
488
- h = torch.cat(results, dim=-1)
489
- return self.out(h)
490
- h = h.type(x.dtype)
491
- return self.out(h)
492
-
493
-
494
- class UNetModel(nn.Module):
495
- def __init__(
496
- self,
497
- image_size,
498
- in_channels,
499
- model_channels,
500
- out_channels,
501
- num_res_blocks,
502
- attention_resolutions,
503
- dropout=0,
504
- channel_mult=(1, 2, 4, 8),
505
- conv_resample=True,
506
- dims=2,
507
- num_classes=None,
508
- use_checkpoint=False,
509
- use_fp16=False,
510
- num_heads=1,
511
- num_head_channels=-1,
512
- num_heads_upsample=-1,
513
- use_scale_shift_norm=False,
514
- resblock_updown=False,
515
- use_new_attention_order=False,
516
- ):
517
- super().__init__()
518
- if num_heads_upsample == -1:
519
- num_heads_upsample = num_heads
520
-
521
- self.model_channels = model_channels
522
- self.num_classes = num_classes
523
- self.dtype = torch.float16 if use_fp16 else torch.float32
524
-
525
- time_embed_dim = model_channels * 4
526
- self.time_embed = nn.Sequential(
527
- linear(model_channels, time_embed_dim),
528
- nn.SiLU(),
529
- linear(time_embed_dim, time_embed_dim),
530
- )
531
- if self.num_classes is not None:
532
- self.label_emb = nn.Embedding(num_classes, time_embed_dim)
533
-
534
- ch = input_ch = int(channel_mult[0] * model_channels)
535
- self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))])
536
- input_block_chans = [ch]
537
- ds = 1
538
- for level, mult in enumerate(channel_mult):
539
- for _ in range(num_res_blocks):
540
- layers = [
541
- ResBlock(
542
- ch,
543
- time_embed_dim,
544
- dropout,
545
- out_channels=int(mult * model_channels),
546
- dims=dims,
547
- use_checkpoint=use_checkpoint,
548
- use_scale_shift_norm=use_scale_shift_norm,
549
- )
550
- ]
551
- ch = int(mult * model_channels)
552
- if ds in attention_resolutions:
553
- layers.append(
554
- AttentionBlock(
555
- ch,
556
- use_checkpoint=use_checkpoint,
557
- num_heads=num_heads,
558
- num_head_channels=num_head_channels,
559
- use_new_attention_order=use_new_attention_order,
560
- )
561
- )
562
- self.input_blocks.append(TimestepEmbedSequential(*layers))
563
- input_block_chans.append(ch)
564
- if level != len(channel_mult) - 1:
565
- out_ch = ch
566
- self.input_blocks.append(
567
- TimestepEmbedSequential(
568
- ResBlock(
569
- ch,
570
- time_embed_dim,
571
- dropout,
572
- out_channels=out_ch,
573
- dims=dims,
574
- use_checkpoint=use_checkpoint,
575
- use_scale_shift_norm=use_scale_shift_norm,
576
- down=True,
577
- )
578
- if resblock_updown
579
- else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
580
- )
581
- )
582
- ch = out_ch
583
- input_block_chans.append(ch)
584
- ds *= 2
585
-
586
- self.middle_block = TimestepEmbedSequential(
587
- ResBlock(
588
- ch,
589
- time_embed_dim,
590
- dropout,
591
- dims=dims,
592
- use_checkpoint=use_checkpoint,
593
- use_scale_shift_norm=use_scale_shift_norm,
594
- ),
595
- AttentionBlock(
596
- ch,
597
- use_checkpoint=use_checkpoint,
598
- num_heads=num_heads,
599
- num_head_channels=num_head_channels,
600
- use_new_attention_order=use_new_attention_order,
601
- ),
602
- ResBlock(
603
- ch,
604
- time_embed_dim,
605
- dropout,
606
- dims=dims,
607
- use_checkpoint=use_checkpoint,
608
- use_scale_shift_norm=use_scale_shift_norm,
609
- ),
610
- )
611
-
612
- self.output_blocks = nn.ModuleList([])
613
- for level, mult in list(enumerate(channel_mult))[::-1]:
614
- for i in range(num_res_blocks + 1):
615
- ich = input_block_chans.pop()
616
- layers = [
617
- ResBlock(
618
- ch + ich,
619
- time_embed_dim,
620
- dropout,
621
- out_channels=int(model_channels * mult),
622
- dims=dims,
623
- use_checkpoint=use_checkpoint,
624
- use_scale_shift_norm=use_scale_shift_norm,
625
- )
626
- ]
627
- ch = int(model_channels * mult)
628
- if ds in attention_resolutions:
629
- layers.append(
630
- AttentionBlock(
631
- ch,
632
- use_checkpoint=use_checkpoint,
633
- num_heads=num_heads_upsample,
634
- num_head_channels=num_head_channels,
635
- use_new_attention_order=use_new_attention_order,
636
- )
637
- )
638
- if level and i == num_res_blocks:
639
- out_ch = ch
640
- layers.append(
641
- ResBlock(
642
- ch,
643
- time_embed_dim,
644
- dropout,
645
- out_channels=out_ch,
646
- dims=dims,
647
- use_checkpoint=use_checkpoint,
648
- use_scale_shift_norm=use_scale_shift_norm,
649
- up=True,
650
- )
651
- if resblock_updown
652
- else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
653
- )
654
- ds //= 2
655
- self.output_blocks.append(TimestepEmbedSequential(*layers))
656
-
657
- self.out = nn.Sequential(
658
- normalization(ch),
659
- nn.SiLU(),
660
- zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
661
- )
662
-
663
- def convert_to_fp16(self):
664
- self.input_blocks.apply(convert_module_to_f16)
665
- self.middle_block.apply(convert_module_to_f16)
666
- self.output_blocks.apply(convert_module_to_f16)
667
-
668
- def convert_to_fp32(self):
669
- self.input_blocks.apply(convert_module_to_f32)
670
- self.middle_block.apply(convert_module_to_f32)
671
- self.output_blocks.apply(convert_module_to_f32)
672
-
673
- def forward(self, x, timesteps, y: Optional[torch.Tensor] = None):
674
- assert (y is not None) == (self.num_classes is not None)
675
- hs = []
676
- emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
677
- if self.num_classes is not None:
678
- assert y.shape == (x.shape[0],)
679
- emb = emb + self.label_emb(y)
680
-
681
- h = x.type(self.dtype)
682
- for module in self.input_blocks:
683
- h = module(h, emb)
684
- hs.append(h)
685
- h = self.middle_block(h, emb)
686
- for module in self.output_blocks:
687
- h = torch.cat([h, hs.pop()], dim=1)
688
- h = module(h, emb)
689
- h = h.type(x.dtype)
690
- return self.out(h)
691
-
692
-
693
- def _default_channel_mult(image_size: int):
694
- if image_size == 512:
695
- return (0.5, 1, 1, 2, 2, 4, 4)
696
- if image_size == 256:
697
- return (1, 1, 2, 2, 4, 4)
698
- if image_size == 128:
699
- return (1, 1, 2, 3, 4)
700
- if image_size == 64:
701
- return (1, 2, 3, 4)
702
- raise ValueError(f"unsupported image size: {image_size}")
703
-
704
-
705
- def create_adm_unet_model(
706
- image_size,
707
- num_channels,
708
- num_res_blocks,
709
- channel_mult="",
710
- learn_sigma=False,
711
- class_cond=False,
712
- use_checkpoint=False,
713
- attention_resolutions="16",
714
- num_heads=1,
715
- num_head_channels=-1,
716
- num_heads_upsample=-1,
717
- use_scale_shift_norm=False,
718
- dropout=0.0,
719
- resblock_updown=False,
720
- use_fp16=False,
721
- use_new_attention_order=False,
722
- ):
723
- channel_mult = _default_channel_mult(image_size) if channel_mult == "" else tuple(int(v) for v in channel_mult.split(","))
724
- attention_ds = tuple(image_size // int(res) for res in attention_resolutions.split(","))
725
- return UNetModel(
726
- image_size=image_size,
727
- in_channels=3,
728
- model_channels=num_channels,
729
- out_channels=(3 if not learn_sigma else 6),
730
- num_res_blocks=num_res_blocks,
731
- attention_resolutions=attention_ds,
732
- dropout=dropout,
733
- channel_mult=channel_mult,
734
- num_classes=(NUM_CLASSES if class_cond else None),
735
- use_checkpoint=use_checkpoint,
736
- use_fp16=use_fp16,
737
- num_heads=num_heads,
738
- num_head_channels=num_head_channels,
739
- num_heads_upsample=num_heads_upsample,
740
- use_scale_shift_norm=use_scale_shift_norm,
741
- resblock_updown=resblock_updown,
742
- use_new_attention_order=use_new_attention_order,
743
- )
744
-
745
-
746
- def create_adm_classifier_model(
747
- image_size: int,
748
- classifier_width: int = 128,
749
- classifier_depth: int = 2,
750
- classifier_attention_resolutions: str = "32,16,8",
751
- classifier_use_scale_shift_norm: bool = True,
752
- classifier_resblock_updown: bool = True,
753
- classifier_pool: str = "attention",
754
- use_fp16: bool = False,
755
- num_classes: int = NUM_CLASSES,
756
- ):
757
- channel_mult = _default_channel_mult(image_size)
758
- attention_ds = tuple(image_size // int(res) for res in classifier_attention_resolutions.split(","))
759
- return EncoderUNetModel(
760
- image_size=image_size,
761
- in_channels=3,
762
- model_channels=classifier_width,
763
- out_channels=num_classes,
764
- num_res_blocks=classifier_depth,
765
- attention_resolutions=attention_ds,
766
- channel_mult=channel_mult,
767
- use_fp16=use_fp16,
768
- num_head_channels=64,
769
- use_scale_shift_norm=classifier_use_scale_shift_norm,
770
- resblock_updown=classifier_resblock_updown,
771
- pool=classifier_pool,
772
- )