heboya8 commited on
Commit
c9e511a
·
verified ·
1 Parent(s): a96e046

Upload unet_3d_blocks.py

Browse files
Files changed (1) hide show
  1. models/unet_3d_blocks.py +875 -0
models/unet_3d_blocks.py ADDED
@@ -0,0 +1,875 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.utils.checkpoint as checkpoint
17
+ from torch import nn
18
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
19
+ from diffusers.models.transformer_2d import Transformer2DModel
20
+ from diffusers.models.transformer_temporal import TransformerTemporalModel
21
+
22
+ # Assign gradient checkpoint function to simple variable for readability.
23
+ g_c = checkpoint.checkpoint
24
+
25
def is_video(num_frames, only_video=True):
    """Return True when the input spans multiple frames (i.e. is a video).

    A single frame (``num_frames == 1``) is never treated as video, so the
    result reduces to ``num_frames > 1``.  ``only_video`` is kept for
    backward compatibility with existing callers; it does not change the
    result (the original early-return branch for ``only_video=False`` also
    produced False, which ``num_frames > 1`` already yields).
    """
    return num_frames > 1
29
+
30
def custom_checkpoint(module, mode=None):
    """Wrap ``module`` in a closure with the call signature expected by
    ``torch.utils.checkpoint.checkpoint`` for the given layer kind.

    Args:
        module: The layer to wrap (resnet, attention, or temporal layer).
        mode: One of ``'resnet'``, ``'attn'`` or ``'temp'``; selects the
            call convention of the returned closure.

    Returns:
        A forward function suitable for gradient checkpointing, or ``None``
        if ``mode`` is an unrecognized string (original behavior preserved).

    Raises:
        ValueError: If ``mode`` is None.
    """
    # Fixed: identity comparison with None (`is None`), not `== None`.
    if mode is None:
        raise ValueError('Mode for gradient checkpointing cannot be none.')

    custom_forward = None

    if mode == 'resnet':
        def custom_forward(hidden_states, temb):
            inputs = module(hidden_states, temb)
            return inputs

    if mode == 'attn':
        def custom_forward(
            hidden_states,
            encoder_hidden_states=None,
            cross_attention_kwargs=None,
            attention_mask=None,
        ):
            inputs = module(
                hidden_states,
                encoder_hidden_states,
                cross_attention_kwargs,
                attention_mask
            )
            # Transformer2DModel returns an output object; unwrap the tensor.
            return inputs.sample

    if mode == 'temp':
        # If the input is a single image (num_frames == 1), temporal layers
        # are a no-op and the hidden states pass through unchanged.
        def custom_forward(hidden_states, num_frames=None):
            if not is_video(num_frames):
                return hidden_states
            else:
                inputs = module(
                    hidden_states,
                    num_frames=num_frames
                )
                # TransformerTemporalModel wraps its result; TemporalConvLayer
                # returns the tensor directly.
                if isinstance(module, TransformerTemporalModel):
                    return inputs.sample
                else:
                    return inputs

    return custom_forward
73
+
74
def transformer_g_c(transformer, sample, num_frames):
    """Run a temporal transformer under gradient checkpointing."""
    forward_fn = custom_checkpoint(transformer, mode='temp')
    return g_c(forward_fn, sample, num_frames, use_reentrant=False)
79
+
80
def cross_attn_g_c(
    attn,
    temp_attn,
    resnet,
    temp_conv,
    hidden_states,
    encoder_hidden_states,
    cross_attention_kwargs,
    temb,
    num_frames,
    inverse_temp=False,
    attention_mask=None,
):
    """Run one attention / resnet / temporal group with each sub-layer
    individually gradient-checkpointed.

    The execution order depends on ``inverse_temp``: attention layers first
    (False, used by the mid block) or resnet/temporal-conv first (True, used
    by the down/up blocks).
    """

    def run_attn(states):
        # Self and CrossAttention
        return g_c(
            custom_checkpoint(attn, mode='attn'),
            states,
            encoder_hidden_states,
            cross_attention_kwargs,
            attention_mask,
            use_reentrant=False,
        )

    def run_temp_attn(states):
        # Temporal Self and CrossAttention
        return g_c(
            custom_checkpoint(temp_attn, mode='temp'),
            states,
            num_frames,
            use_reentrant=False,
        )

    def run_resnet(states):
        # Resnets
        return g_c(
            custom_checkpoint(resnet, mode='resnet'),
            states,
            temb,
            use_reentrant=False,
        )

    def run_temp_conv(states):
        # Temporal Convolutions
        return g_c(
            custom_checkpoint(temp_conv, mode='temp'),
            states,
            num_frames,
            use_reentrant=False,
        )

    # Choose the layer order; some blocks run the temporal pair first.
    if inverse_temp:
        pipeline = (run_resnet, run_temp_conv, run_attn, run_temp_attn)
    else:
        pipeline = (run_attn, run_temp_attn, run_resnet, run_temp_conv)

    for layer in pipeline:
        hidden_states = layer(hidden_states)

    return hidden_states
141
+
142
def up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames):
    """Checkpoint a resnet block followed by its temporal convolution."""
    resnet_fn = custom_checkpoint(resnet, mode='resnet')
    temp_fn = custom_checkpoint(temp_conv, mode='temp')
    hidden_states = g_c(resnet_fn, hidden_states, temb, use_reentrant=False)
    return g_c(temp_fn, hidden_states, num_frames, use_reentrant=False)
154
+
155
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=True,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
):
    """Factory returning the 3D down block matching ``down_block_type``.

    Raises:
        ValueError: For an unknown block type, or when ``cross_attention_dim``
            is missing for a cross-attention block.
    """
    # Arguments common to both block types.
    shared_kwargs = dict(
        num_layers=num_layers,
        in_channels=in_channels,
        out_channels=out_channels,
        temb_channels=temb_channels,
        add_downsample=add_downsample,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
        resnet_groups=resnet_groups,
        downsample_padding=downsample_padding,
        resnet_time_scale_shift=resnet_time_scale_shift,
    )

    if down_block_type == "DownBlock3D":
        return DownBlock3D(**shared_kwargs)

    if down_block_type == "CrossAttnDownBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
        return CrossAttnDownBlock3D(
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            **shared_kwargs,
        )

    raise ValueError(f"{down_block_type} does not exist.")
209
+
210
+
211
def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=True,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
):
    """Factory returning the 3D up block matching ``up_block_type``.

    Raises:
        ValueError: For an unknown block type, or when ``cross_attention_dim``
            is missing for a cross-attention block.
    """
    # Arguments common to both block types.
    shared_kwargs = dict(
        num_layers=num_layers,
        in_channels=in_channels,
        out_channels=out_channels,
        prev_output_channel=prev_output_channel,
        temb_channels=temb_channels,
        add_upsample=add_upsample,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
        resnet_groups=resnet_groups,
        resnet_time_scale_shift=resnet_time_scale_shift,
    )

    if up_block_type == "UpBlock3D":
        return UpBlock3D(**shared_kwargs)

    if up_block_type == "CrossAttnUpBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
        return CrossAttnUpBlock3D(
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            **shared_kwargs,
        )

    raise ValueError(f"{up_block_type} does not exist.")
265
+
266
+
267
class UNetMidBlock3DCrossAttn(nn.Module):
    """Middle block of a 3D UNet.

    Layout: one leading ResnetBlock2D + TemporalConvLayer pair, then
    ``num_layers`` groups of (Transformer2DModel, TransformerTemporalModel,
    ResnetBlock2D, TemporalConvLayer).  Temporal layers are only applied when
    the input has more than one frame.  Setting ``self.gradient_checkpointing``
    to True routes every sub-layer through gradient checkpointing.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=True,
        upcast_attention=False,
    ):
        super().__init__()

        # Toggled externally to enable per-layer gradient checkpointing.
        self.gradient_checkpointing = False
        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels
        # Fall back to a group count that divides the channel width.
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        # there is always at least one resnet
        resnets = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        temp_convs = [
            TemporalConvLayer(
                in_channels,
                in_channels,
                dropout=0.1
            )
        ]
        attentions = []
        temp_attentions = []

        for _ in range(num_layers):
            attentions.append(
                Transformer2DModel(
                    in_channels // attn_num_head_channels,
                    attn_num_head_channels,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                )
            )
            temp_attentions.append(
                TransformerTemporalModel(
                    in_channels // attn_num_head_channels,
                    attn_num_head_channels,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            )
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_convs.append(
                TemporalConvLayer(
                    in_channels,
                    in_channels,
                    dropout=0.1
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)
        self.attentions = nn.ModuleList(attentions)
        self.temp_attentions = nn.ModuleList(temp_attentions)

    def forward(
        self,
        hidden_states,
        temb=None,
        encoder_hidden_states=None,
        attention_mask=None,
        num_frames=1,
        cross_attention_kwargs=None,
    ):
        """Apply the mid block and return the transformed hidden states.

        NOTE(review): `attention_mask` is accepted but never used here —
        mirrors the TODO in the sibling blocks; confirm against callers.
        """
        # Leading resnet + temporal conv (extra entry in self.resnets[0]).
        if self.gradient_checkpointing:
            hidden_states = up_down_g_c(
                self.resnets[0],
                self.temp_convs[0],
                hidden_states,
                temb,
                num_frames
            )
        else:
            hidden_states = self.resnets[0](hidden_states, temb)
            hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)

        # Remaining layers: attention first, then resnet/temporal conv.
        for attn, temp_attn, resnet, temp_conv in zip(
            self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
        ):
            if self.gradient_checkpointing:
                hidden_states = cross_attn_g_c(
                    attn,
                    temp_attn,
                    resnet,
                    temp_conv,
                    hidden_states,
                    encoder_hidden_states,
                    cross_attention_kwargs,
                    temb,
                    num_frames
                )
            else:
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # Temporal layers only run on multi-frame (video) input.
                if num_frames > 1:
                    hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample

                hidden_states = resnet(hidden_states, temb)

                if num_frames > 1:
                    hidden_states = temp_conv(hidden_states, num_frames=num_frames)

        return hidden_states
420
+
421
+
422
class CrossAttnDownBlock3D(nn.Module):
    """Down block with cross-attention.

    Layout: ``num_layers`` groups of (ResnetBlock2D, TemporalConvLayer,
    Transformer2DModel, TransformerTemporalModel), optionally followed by a
    spatial Downsample2D.  ``forward`` returns both the final hidden states
    and a tuple of per-layer states used as UNet skip connections.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
    ):
        super().__init__()
        resnets = []
        attentions = []
        temp_attentions = []
        temp_convs = []

        # Toggled externally to enable per-layer gradient checkpointing.
        self.gradient_checkpointing = False
        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            # Only the first resnet maps in_channels -> out_channels.
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_convs.append(
                TemporalConvLayer(
                    out_channels,
                    out_channels,
                    dropout=0.1
                )
            )
            attentions.append(
                Transformer2DModel(
                    out_channels // attn_num_head_channels,
                    attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                )
            )
            temp_attentions.append(
                TransformerTemporalModel(
                    out_channels // attn_num_head_channels,
                    attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            )
        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)
        self.attentions = nn.ModuleList(attentions)
        self.temp_attentions = nn.ModuleList(temp_attentions)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

    def forward(
        self,
        hidden_states,
        temb=None,
        encoder_hidden_states=None,
        attention_mask=None,
        num_frames=1,
        cross_attention_kwargs=None,
    ):
        """Apply each layer group; returns (hidden_states, output_states).

        ``output_states`` collects the hidden states after every layer group
        (and after downsampling) for use as skip connections.
        """
        # TODO(Patrick, William) - attention mask is not used
        output_states = ()

        for resnet, temp_conv, attn, temp_attn in zip(
            self.resnets, self.temp_convs, self.attentions, self.temp_attentions
        ):

            if self.gradient_checkpointing:
                # inverse_temp=True: resnet/temporal-conv run before attention.
                hidden_states = cross_attn_g_c(
                    attn,
                    temp_attn,
                    resnet,
                    temp_conv,
                    hidden_states,
                    encoder_hidden_states,
                    cross_attention_kwargs,
                    temb,
                    num_frames,
                    inverse_temp=True
                )
            else:
                hidden_states = resnet(hidden_states, temb)

                # Temporal layers only run on multi-frame (video) input.
                if num_frames > 1:
                    hidden_states = temp_conv(hidden_states, num_frames=num_frames)

                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                if num_frames > 1:
                    hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states
570
+
571
+
572
class DownBlock3D(nn.Module):
    """Down block without cross-attention.

    Layout: ``num_layers`` pairs of (ResnetBlock2D, TemporalConvLayer),
    optionally followed by a spatial Downsample2D.  ``forward`` returns the
    final hidden states plus a tuple of per-layer skip states.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()
        resnets = []
        temp_convs = []

        # Toggled externally to enable per-layer gradient checkpointing.
        self.gradient_checkpointing = False
        for i in range(num_layers):
            # Only the first resnet maps in_channels -> out_channels.
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_convs.append(
                TemporalConvLayer(
                    out_channels,
                    out_channels,
                    dropout=0.1
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

    def forward(self, hidden_states, temb=None, num_frames=1):
        """Apply resnet/temporal-conv pairs; returns (hidden_states, output_states)."""
        output_states = ()

        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            if self.gradient_checkpointing:
                hidden_states = up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames)
            else:
                hidden_states = resnet(hidden_states, temb)

                # Temporal conv only applies to multi-frame (video) input.
                if num_frames > 1:
                    hidden_states = temp_conv(hidden_states, num_frames=num_frames)

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states
653
+
654
+
655
class CrossAttnUpBlock3D(nn.Module):
    """Up block with cross-attention.

    Each layer concatenates the matching skip connection from the down path
    onto the channel dimension before running (ResnetBlock2D,
    TemporalConvLayer, Transformer2DModel, TransformerTemporalModel).
    Optionally finishes with a spatial Upsample2D.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
    ):
        super().__init__()
        resnets = []
        temp_convs = []
        attentions = []
        temp_attentions = []

        # Toggled externally to enable per-layer gradient checkpointing.
        self.gradient_checkpointing = False
        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            # Channel bookkeeping for the concatenated skip connection:
            # the last layer's skip comes from the block above (in_channels),
            # the first resnet input comes from the previous up block.
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_convs.append(
                TemporalConvLayer(
                    out_channels,
                    out_channels,
                    dropout=0.1
                )
            )
            attentions.append(
                Transformer2DModel(
                    out_channels // attn_num_head_channels,
                    attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                )
            )
            temp_attentions.append(
                TransformerTemporalModel(
                    out_channels // attn_num_head_channels,
                    attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            )
        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)
        self.attentions = nn.ModuleList(attentions)
        self.temp_attentions = nn.ModuleList(temp_attentions)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
        attention_mask=None,
        num_frames=1,
        cross_attention_kwargs=None,
    ):
        """Apply each layer group, consuming ``res_hidden_states_tuple`` from
        the end (skips are popped in reverse order of how the down path
        produced them), then optionally upsample.
        """
        # TODO(Patrick, William) - attention mask is not used
        for resnet, temp_conv, attn, temp_attn in zip(
            self.resnets, self.temp_convs, self.attentions, self.temp_attentions
        ):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.gradient_checkpointing:
                # inverse_temp=True: resnet/temporal-conv run before attention.
                hidden_states = cross_attn_g_c(
                    attn,
                    temp_attn,
                    resnet,
                    temp_conv,
                    hidden_states,
                    encoder_hidden_states,
                    cross_attention_kwargs,
                    temb,
                    num_frames,
                    inverse_temp=True
                )
            else:
                hidden_states = resnet(hidden_states, temb)

                # Temporal layers only run on multi-frame (video) input.
                if num_frames > 1:
                    hidden_states = temp_conv(hidden_states, num_frames=num_frames)

                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                if num_frames > 1:
                    hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
799
+
800
+
801
class UpBlock3D(nn.Module):
    """Up block without cross-attention.

    Each layer concatenates the matching skip connection from the down path
    onto the channel dimension before running (ResnetBlock2D,
    TemporalConvLayer).  Optionally finishes with a spatial Upsample2D.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super().__init__()
        resnets = []
        temp_convs = []
        # Toggled externally to enable per-layer gradient checkpointing.
        self.gradient_checkpointing = False
        for i in range(num_layers):
            # Channel bookkeeping for the concatenated skip connection:
            # the last layer's skip comes from the block above (in_channels),
            # the first resnet input comes from the previous up block.
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_convs.append(
                TemporalConvLayer(
                    out_channels,
                    out_channels,
                    dropout=0.1
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
        """Apply resnet/temporal-conv pairs, consuming skips from the end of
        ``res_hidden_states_tuple``, then optionally upsample.
        """
        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.gradient_checkpointing:
                hidden_states = up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames)
            else:
                hidden_states = resnet(hidden_states, temb)

                # Temporal conv only applies to multi-frame (video) input.
                if num_frames > 1:
                    hidden_states = temp_conv(hidden_states, num_frames=num_frames)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states