Xsmos committed on
Commit 3df929e · verified · 1 Parent(s): e9aa37b
.context_unet_backup.py.swp ADDED
Binary file (24.6 kB).
 
context_unet_backup.py ADDED
@@ -0,0 +1,591 @@
+ # from dataclasses import dataclass
+ # import h5py
+ import torch
+ import torch.nn as nn
+ # from torch.utils.data import DataLoader, Dataset
+ # from datasets import Dataset
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import random
+ from abc import ABC, abstractmethod
+ import torch.nn.functional as F
+ import math
+ # from PIL import Image
+ import os
+ # from torch.utils.tensorboard import SummaryWriter
+ import copy
+ # from tqdm.auto import tqdm
+ # from torchvision import transforms
+ # from diffusers import UNet2DModel#, UNet3DConditionModel
+ # from diffusers import DDPMScheduler
+ # from diffusers.utils import make_image_grid
+ import datetime
+ import torch.utils.checkpoint as checkpoint
+ # from pathlib import Path
+ # from diffusers.optimization import get_cosine_schedule_with_warmup
+ # from accelerate import notebook_launcher, Accelerator
+ # from huggingface_hub import create_repo, upload_folder
+ # from load_h5 import Dataset4h5
+
+ class GroupNorm32(nn.GroupNorm):
+     def __init__(self, num_groups, num_channels, swish, eps=1e-5):
+         super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
+         self.swish = swish
+
+     def forward(self, x):
+         y = super().forward(x)
+         if self.swish == 1.0:
+             y = F.silu(y)
+         elif self.swish:
+             y = y * torch.sigmoid(y * float(self.swish))
+         return y
+
+ def normalization(channels, swish=0.0):
+     """
+     Make a standard GroupNorm layer, with an optional swish activation.
+
+     :param channels: number of input channels.
+     :param swish: swish scale; 1.0 gives standard SiLU, 0.0 disables the activation.
+     :return: an nn.Module for normalization.
+     """
+     return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
+
+ Conv = {
+     1: nn.Conv1d,
+     2: nn.Conv2d,
+     3: nn.Conv3d,
+ }
+
+ AvgPool = {
+     1: nn.AvgPool1d,
+     2: nn.AvgPool2d,
+     3: nn.AvgPool3d,
+ }
+
+ class Downsample(nn.Module):
+     def __init__(self, channels, use_conv, out_channels=None, dim=2, stride=(2, 2), use_checkpoint=False):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_checkpoint = use_checkpoint
+         self.dim = dim
+         if use_conv:
+             self.op = Conv[dim](channels, self.out_channels, 3, stride=stride, padding=1)
+         else:
+             assert channels == self.out_channels
+             self.op = AvgPool[dim](kernel_size=stride, stride=stride)
+
+     def forward(self, x):
+         assert x.shape[1] == self.channels
+         if self.use_checkpoint and isinstance(self.op, Conv[self.dim]):
+             return checkpoint.checkpoint(self.op, x, use_reentrant=False)
+         else:
+             return self.op(x)
+
+ class Upsample(nn.Module):
+     def __init__(self, channels, use_conv, out_channels=None, dim=2, stride=(2, 2), use_checkpoint=False):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.stride = stride
+         self.use_checkpoint = use_checkpoint
+
+         if self.use_conv:
+             self.conv = Conv[dim](self.channels, self.out_channels, 3, padding=1)
+
+     def forward(self, x):
+         assert x.shape[1] == self.channels
+         # scale each spatial dimension by the corresponding stride
+         shape = tuple(int(s * k) for s, k in zip(x.shape[2:], self.stride))
+         x = F.interpolate(x, shape, mode='nearest')
+
+         if self.use_conv:
+             if self.use_checkpoint:
+                 return checkpoint.checkpoint(self.conv, x, use_reentrant=False)
+             else:
+                 x = self.conv(x)
+
+         return x
+
+ def zero_module(module):
+     """
+     Zero out the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().zero_()
+     return module
+
+ class TimestepBlock(ABC, nn.Module):
+     @abstractmethod
+     def forward(self, x, emb):
+         """
+         Apply the module to `x` given `emb` timestep embeddings.
+         """
+
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+     def forward(self, x, emb, encoder_out=None):
+         for layer in self:
+             if isinstance(layer, TimestepBlock):
+                 x = layer(x, emb)
+             elif isinstance(layer, AttentionBlock):
+                 x = layer(x, encoder_out)
+             else:
+                 x = layer(x)
+         return x
+
+ class ResBlock(TimestepBlock):
+     def __init__(
+         self, channels, emb_channels, dropout, out_channels=None, use_conv=False,
+         use_checkpoint=False, use_scale_shift_norm=False, up=False, down=False,
+         dim=2, stride=(2, 2),
+     ):
+         super().__init__()
+         self.out_channels = out_channels or channels
+         self.use_scale_shift_norm = use_scale_shift_norm
+         self.stride = stride
+         self.use_checkpoint = use_checkpoint
+
+         self.in_layers = nn.Sequential(
+             normalization(channels, swish=1.0),
+             nn.Identity(),
+             Conv[dim](channels, self.out_channels, 3, padding=1),
+         )
+
+         self.updown = up or down
+         if up:
+             self.h_updown = Upsample(channels, False, dim=dim, stride=stride)
+             self.x_updown = Upsample(channels, False, dim=dim, stride=stride)
+         elif down:
+             self.h_updown = Downsample(channels, False, dim=dim, stride=stride)
+             self.x_updown = Downsample(channels, False, dim=dim, stride=stride)
+         else:
+             self.h_updown = self.x_updown = nn.Identity()
+
+         self.emb_layers = nn.Sequential(
+             nn.SiLU(),
+             nn.Linear(
+                 emb_channels,
+                 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+             ),
+         )
+
+         self.out_layers = nn.Sequential(
+             normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
+             nn.SiLU() if use_scale_shift_norm else nn.Identity(),
+             nn.Dropout(p=dropout),
+             zero_module(Conv[dim](self.out_channels, self.out_channels, 3, padding=1)),
+         )
+
+         if self.out_channels == channels:
+             self.skip_connection = nn.Identity()
+         elif use_conv:
+             self.skip_connection = Conv[dim](channels, self.out_channels, 3, padding=1)
+         else:
+             self.skip_connection = Conv[dim](channels, self.out_channels, 1)
+
+     def forward(self, x, emb):
+         if self.use_checkpoint:
+             return checkpoint.checkpoint(self._forward_impl, x, emb, use_reentrant=False)
+         else:
+             return self._forward_impl(x, emb)
+
+     def _forward_impl(self, x, emb):
+         if self.updown:
+             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+             h = in_rest(x)
+             h = self.h_updown(h)
+             x = self.x_updown(x)
+             h = in_conv(h)
+         else:
+             h = self.in_layers(x)
+         emb_out = self.emb_layers(emb)
+
+         while len(emb_out.shape) < len(h.shape):
+             emb_out = emb_out[..., None]
+
+         if self.use_scale_shift_norm:
+             out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+             scale, shift = torch.chunk(emb_out, 2, dim=1)
+             h = out_norm(h) * (1 + scale) + shift
+             h = out_rest(h)
+         else:
+             h = h + emb_out
+             h = self.out_layers(h)
+         return self.skip_connection(x) + h
+
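+ # Editorial note on the two conditioning modes above (not in the original code):
+ # with use_scale_shift_norm=True the embedding is split into (scale, shift) and
+ # applied as h = norm(h) * (1 + scale) + shift (AdaGN/FiLM-style conditioning);
+ # otherwise the embedding is simply added to the feature map before out_layers.
+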
+ class QKVAttention(nn.Module):
+     def __init__(self, n_heads):
+         super().__init__()
+         self.n_heads = n_heads
+
+     def forward(self, qkv, encoder_kv=None):
+         bs, width, length = qkv.shape
+         assert width % (3 * self.n_heads) == 0
+         ch = width // (3 * self.n_heads)
+
+         q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+         if encoder_kv is not None:
+             assert encoder_kv.shape[1] == self.n_heads * ch * 2
+             ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
+             k = torch.cat([ek, k], dim=-1)
+             v = torch.cat([ev, v], dim=-1)
+
+         scale = 1 / math.sqrt(math.sqrt(ch))
+         weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)
+         weight = torch.softmax(weight.float(), dim=-1)
+
+         a = torch.einsum("bts,bcs->bct", weight, v)
+         return a.reshape(bs, -1, length)
+
+ class AttentionBlock(nn.Module):
+     def __init__(
+         self,
+         channels,
+         num_heads=1,
+         num_head_channels=-1,
+         use_checkpoint=False,
+         encoder_channels=None,
+     ):
+         super().__init__()
+         self.channels = channels
+         if num_head_channels == -1:
+             self.num_heads = num_heads
+         else:
+             assert channels % num_head_channels == 0, \
+                 f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+             self.num_heads = channels // num_head_channels
+
+         self.use_checkpoint = use_checkpoint
+         self.norm = normalization(channels, swish=0.0)
+         self.qkv = nn.Conv1d(channels, channels * 3, 1)
+
+         self.attention = QKVAttention(self.num_heads)
+
+         if encoder_channels is not None:
+             self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
+         self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
+
+     def forward(self, x, encoder_out=None):
+         if self.use_checkpoint:
+             return checkpoint.checkpoint(self._forward_impl, x, encoder_out, use_reentrant=False)
+         else:
+             return self._forward_impl(x, encoder_out)
+
+     def _forward_impl(self, x, encoder_out=None):
+         b, c, *spatial = x.shape
+         qkv = self.qkv(self.norm(x).view(b, c, -1))
+         if encoder_out is not None:
+             encoder_out = self.encoder_kv(encoder_out)
+             h = self.attention(qkv, encoder_out)
+         else:
+             h = self.attention(qkv)
+         h = self.proj_out(h)
+         return x + h.reshape(b, c, *spatial)
+
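+ # Editorial note (not in the original code): when encoder_out is given, its
+ # projected keys/values are concatenated with the feature map's own k/v inside
+ # QKVAttention, so the block performs joint self- and cross-attention.
+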
+ def timestep_embedding(timesteps, dim, max_period=10000):
+     """
+     Create sinusoidal timestep embeddings.
+
+     :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                       These may be fractional.
+     :param dim: the dimension of the output.
+     :param max_period: controls the minimum frequency of the embeddings.
+     :return: an [N x dim] Tensor of positional embeddings.
+     """
+     half = dim // 2
+     freqs = torch.exp(
+         -math.log(max_period) * torch.arange(start=0, end=half) / half
+     ).to(device=timesteps.device)
+     args = timesteps[:, None].float() * freqs[None]
+     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+     if dim % 2:
+         embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+     return embedding
+
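+ # Editorial shape example (not in the original code):
+ # timestep_embedding(torch.tensor([0., 10., 500.]), dim=128) returns a (3, 128)
+ # tensor whose first 64 columns are cosine terms and last 64 are sine terms,
+ # with frequencies spaced geometrically between 1 and roughly 1/max_period.
+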
+ class ContextUnet(nn.Module):
+     def __init__(
+         self,
+         n_param=2,
+         image_size=64,
+         in_channels=1,
+         model_channels=128,
+         out_channels=1,
+         channel_mult=None,
+         num_res_blocks=2,
+         dropout=0,
+         use_checkpoint=False,
+         use_scale_shift_norm=False,
+         attention_resolutions=(16, 8),
+         num_heads=4,
+         num_head_channels=-1,
+         num_heads_upsample=-1,
+         resblock_updown=False,
+         conv_resample=True,
+         encoder_channels=None,
+         dim=2,
+         stride=(2, 2),
+     ):
+         super().__init__()
+
+         if channel_mult is None:
+             if image_size == 512:
+                 channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
+             elif image_size == 256:
+                 channel_mult = (1, 1, 2, 2, 4, 4)
+             elif image_size == 128:
+                 channel_mult = (1, 1, 2, 3, 4)
+             elif image_size == 64:
+                 channel_mult = (1, 2, 2, 2, 4)  # one of several multiplier sets tried during tuning
+             elif image_size == 32:
+                 channel_mult = (1, 2, 2, 4)
+             elif image_size == 28:
+                 channel_mult = (1, 2, 4)
+             else:
+                 raise ValueError(f"unsupported image size: {image_size}")
+
+         attention_ds = []
+         for res in attention_resolutions:
+             attention_ds.append(image_size // int(res))
+
+         if num_heads_upsample == -1:
+             num_heads_upsample = num_heads
+
+         self.model_channels = model_channels
+
+         self.token_embedding = nn.Linear(n_param, model_channels * 4)
+
+         time_embed_dim = model_channels * 4
+         self.time_embed = nn.Sequential(
+             nn.Linear(model_channels, time_embed_dim),
+             nn.SiLU(),
+             nn.Linear(time_embed_dim, time_embed_dim),
+         )
+
+         ch = input_ch = int(channel_mult[0] * model_channels)
+
+         ###################### input_blocks ######################
+         self.input_blocks = nn.ModuleList(
+             [TimestepEmbedSequential(Conv[dim](in_channels, ch, 3, padding=1))]
+         )
+         self._feature_size = ch
+         input_block_chans = [ch]
+         ds = 1
+
+         for level, mult in enumerate(channel_mult):
+             for _ in range(num_res_blocks):
+                 layers = [
+                     ResBlock(
+                         ch,
+                         time_embed_dim,
+                         dropout,
+                         out_channels=int(mult * model_channels),
+                         use_checkpoint=use_checkpoint,
+                         use_scale_shift_norm=use_scale_shift_norm,
+                         dim=dim,
+                         stride=stride,
+                     )
+                 ]
+                 ch = int(mult * model_channels)
+                 if ds in attention_ds:
+                     layers.append(
+                         AttentionBlock(
+                             ch,
+                             use_checkpoint=use_checkpoint,
+                             num_heads=num_heads,
+                             num_head_channels=num_head_channels,
+                             encoder_channels=encoder_channels,
+                         )
+                     )
+                 self.input_blocks.append(TimestepEmbedSequential(*layers))
+                 self._feature_size += ch
+                 input_block_chans.append(ch)
+
+             if level != len(channel_mult) - 1:
+                 out_ch = ch
+                 self.input_blocks.append(
+                     TimestepEmbedSequential(
+                         ResBlock(
+                             ch,
+                             time_embed_dim,
+                             dropout,
+                             out_channels=out_ch,
+                             use_checkpoint=use_checkpoint,
+                             use_scale_shift_norm=use_scale_shift_norm,
+                             down=True,
+                             dim=dim,
+                             stride=stride,
+                         )
+                         if resblock_updown
+                         else Downsample(
+                             ch,
+                             conv_resample,
+                             out_channels=out_ch,
+                             dim=dim,
+                             stride=stride,
+                         )
+                     )
+                 )
+                 ch = out_ch
+                 input_block_chans.append(ch)
+                 ds *= 2
+                 self._feature_size += ch
+
+         ###################### middle_block ######################
+         self.middle_block = TimestepEmbedSequential(
+             ResBlock(
+                 ch,
+                 time_embed_dim,
+                 dropout,
+                 use_checkpoint=use_checkpoint,
+                 use_scale_shift_norm=use_scale_shift_norm,
+                 dim=dim,
+                 stride=stride,
+             ),
+             AttentionBlock(
+                 ch,
+                 use_checkpoint=use_checkpoint,
+                 num_heads=num_heads,
+                 num_head_channels=num_head_channels,
+                 encoder_channels=encoder_channels,
+             ),
+             ResBlock(
+                 ch,
+                 time_embed_dim,
+                 dropout,
+                 use_checkpoint=use_checkpoint,
+                 use_scale_shift_norm=use_scale_shift_norm,
+                 dim=dim,
+                 stride=stride,
+             ),
+         )
+         self._feature_size += ch
+
+         ###################### output_blocks ######################
+         self.output_blocks = nn.ModuleList([])
+         for level, mult in list(enumerate(channel_mult))[::-1]:
+             for i in range(num_res_blocks + 1):
+                 ich = input_block_chans.pop()
+                 layers = [
+                     ResBlock(
+                         ch + ich,
+                         time_embed_dim,
+                         dropout,
+                         out_channels=int(model_channels * mult),
+                         use_checkpoint=use_checkpoint,
+                         use_scale_shift_norm=use_scale_shift_norm,
+                         dim=dim,
+                         stride=stride,
+                     )
+                 ]
+                 ch = int(model_channels * mult)
+                 if ds in attention_ds:
+                     layers.append(
+                         AttentionBlock(
+                             ch,
+                             use_checkpoint=use_checkpoint,
+                             num_heads=num_heads_upsample,
+                             num_head_channels=num_head_channels,
+                             encoder_channels=encoder_channels,
+                         )
+                     )
+                 if level and i == num_res_blocks:
+                     out_ch = ch
+                     layers.append(
+                         ResBlock(
+                             ch,
+                             time_embed_dim,
+                             dropout,
+                             out_channels=out_ch,
+                             use_checkpoint=use_checkpoint,
+                             use_scale_shift_norm=use_scale_shift_norm,
+                             up=True,
+                             dim=dim,
+                             stride=stride,
+                         )
+                         if resblock_updown
+                         else Upsample(
+                             ch,
+                             conv_resample,
+                             out_channels=out_ch,
+                             dim=dim,
+                             stride=stride,
+                         )
+                     )
+                     ds //= 2
+                 self.output_blocks.append(TimestepEmbedSequential(*layers))
+                 self._feature_size += ch
+
+         self.out = nn.Sequential(
+             normalization(ch, swish=1.0),
+             nn.Identity(),
+             zero_module(Conv[dim](input_ch, out_channels, 3, padding=1)),
+         )
+
+     def forward(self, x, timesteps, y=None):
+         hs = []
+         emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+         if y is not None:
+             text_outputs = self.token_embedding(y)
+             emb = emb + text_outputs.to(emb)
+
+         h = x.clone()
+         for module in self.input_blocks:
+             h = module(h, emb)
+             hs.append(h)
+         h = self.middle_block(h, emb)
+         for module in self.output_blocks:
+             h = torch.cat([h, hs.pop()], dim=1)
+             h = module(h, emb)
+
+         h = self.out(h)
+         return h
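+
+ # Minimal smoke test (editorial sketch, not part of the original commit; shapes
+ # assume 64x64 single-channel maps conditioned on a 2-parameter vector):
+ if __name__ == "__main__":
+     model = ContextUnet(n_param=2, image_size=64, in_channels=1,
+                         model_channels=128, out_channels=1)
+     x = torch.randn(4, 1, 64, 64)       # batch of noisy fields
+     t = torch.randint(0, 1000, (4,))    # diffusion timesteps
+     y = torch.randn(4, 2)               # conditioning parameters
+     eps = model(x, t, y)                # predicted noise
+     assert eps.shape == x.shape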