camenduru committed on
Commit
4a273c0
·
1 Parent(s): 882b5a4

Delete unet_3d_blocks.py

Browse files
Files changed (1) hide show
  1. unet_3d_blocks.py +0 -842
unet_3d_blocks.py DELETED
@@ -1,842 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import torch
16
- import torch.utils.checkpoint as checkpoint
17
- from torch import nn
18
- from diffusers.models.resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
19
- from diffusers.models.transformer_2d import Transformer2DModel
20
- from diffusers.models.transformer_temporal import TransformerTemporalModel
21
-
22
- # Assign gradient checkpoint function to simple variable for readability.
23
- g_c = checkpoint.checkpoint
24
-
25
def use_temporal(module, num_frames, x):
    """Decide whether a temporal layer can be bypassed for a single frame.

    Args:
        module: The temporal layer that would otherwise be applied to ``x``.
        num_frames: Number of frames being processed.
        x: The hidden states destined for ``module``.

    Returns:
        ``None`` when ``num_frames`` is not 1 (nothing is bypassed). When
        ``num_frames`` is 1, ``x`` is handed back unchanged: wrapped as
        ``{"sample": x}`` if ``module`` is a ``TransformerTemporalModel``,
        and returned bare for any other module.
    """
    if num_frames != 1:
        return None

    is_temporal_transformer = isinstance(module, TransformerTemporalModel)
    return {"sample": x} if is_temporal_transformer else x
31
-
32
def custom_checkpoint(module, mode=None):
    """Build the forward closure that is handed to the checkpoint function.

    Args:
        module: The layer to wrap.
        mode: Selects the calling convention of the returned closure:
            ``'resnet'`` calls ``module(hidden_states, temb)``;
            ``'attn'`` calls ``module`` with ``encoder_hidden_states`` and
            ``cross_attention_kwargs`` passed by keyword;
            ``'temp'`` first consults ``use_temporal`` and only calls
            ``module(hidden_states, num_frames=num_frames)`` when that
            helper returns ``None``.

    Returns:
        A callable ``custom_forward`` matching the selected mode.

    Raises:
        ValueError: If ``mode`` is ``None`` or not one of the supported modes.
    """
    if mode is None:
        raise ValueError('Mode for gradient checkpointing cannot be none.')

    if mode == 'resnet':
        def custom_forward(hidden_states, temb):
            return module(hidden_states, temb)

    elif mode == 'attn':
        def custom_forward(
            hidden_states,
            encoder_hidden_states=None,
            cross_attention_kwargs=None
        ):
            # Pass by keyword, exactly like the non-checkpointed call sites in
            # this file do. Forwarding these positionally depends on the
            # parameter order of the module's forward().
            # NOTE(review): for diffusers' Transformer2DModel the third
            # positional parameter appears to be `timestep`, not
            # `cross_attention_kwargs` - confirm against the pinned version.
            return module(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
            )

    elif mode == 'temp':
        def custom_forward(hidden_states, num_frames=None):
            bypassed = use_temporal(module, num_frames, hidden_states)
            if bypassed is not None:
                return bypassed
            return module(hidden_states, num_frames=num_frames)

    else:
        # Fail here with a clear message instead of returning None and
        # letting the checkpoint call fail on a non-callable later.
        raise ValueError(
            f"Unsupported gradient checkpointing mode: {mode!r}. "
            "Expected 'resnet', 'attn' or 'temp'."
        )

    return custom_forward
64
-
65
def transformer_g_c(transformer, sample, num_frames):
    """Run ``transformer`` through the checkpoint function in ``'temp'`` mode.

    Args:
        transformer: The layer to wrap with ``custom_checkpoint``.
        sample: Input forwarded to the wrapped layer.
        num_frames: Frame count forwarded to the wrapped layer.

    Returns:
        The ``'sample'`` entry of the checkpointed call's result.
    """
    wrapped_forward = custom_checkpoint(transformer, mode='temp')
    result = g_c(wrapped_forward, sample, num_frames, use_reentrant=False)
    return result['sample']
71
-
72
def cross_attn_g_c(
    attn,
    temp_attn,
    resnet,
    temp_conv,
    hidden_states,
    encoder_hidden_states,
    cross_attention_kwargs,
    temb,
    num_frames,
    inverse_temp=False
):
    """Apply one attention/resnet group with every layer checkpointed.

    Args:
        attn: Attention layer, wrapped in ``'attn'`` mode.
        temp_attn: Temporal attention layer, wrapped in ``'temp'`` mode.
        resnet: Resnet layer, wrapped in ``'resnet'`` mode.
        temp_conv: Temporal convolution layer, wrapped in ``'temp'`` mode.
        hidden_states: Input threaded through the four layers.
        encoder_hidden_states: Forwarded to ``attn``.
        cross_attention_kwargs: Forwarded to ``attn``.
        temb: Forwarded to ``resnet``.
        num_frames: Forwarded to ``temp_attn`` and ``temp_conv``.
        inverse_temp: When falsy the order is attn, temp_attn, resnet,
            temp_conv; when truthy it is resnet, temp_conv, attn, temp_attn.

    Returns:
        The hidden states after all four layers have been applied.
    """

    # Self and CrossAttention
    def run_attn(states):
        wrapped = custom_checkpoint(attn, mode='attn')
        return g_c(
            wrapped, states, encoder_hidden_states, cross_attention_kwargs, use_reentrant=False
        )['sample']

    # Temporal Self and CrossAttention
    def run_temp_attn(states):
        wrapped = custom_checkpoint(temp_attn, mode='temp')
        return g_c(wrapped, states, num_frames, use_reentrant=False)['sample']

    # Resnets
    def run_resnet(states):
        wrapped = custom_checkpoint(resnet, mode='resnet')
        return g_c(wrapped, states, temb, use_reentrant=False)

    # Temporal Convolutions
    def run_temp_conv(states):
        wrapped = custom_checkpoint(temp_conv, mode='temp')
        return g_c(wrapped, states, num_frames, use_reentrant=False)

    # Some blocks run the resnet pair before the attention pair, so the stage
    # order is selected by the flag.
    if inverse_temp:
        stages = (run_resnet, run_temp_conv, run_attn, run_temp_attn)
    else:
        stages = (run_attn, run_temp_attn, run_resnet, run_temp_conv)

    for stage in stages:
        hidden_states = stage(hidden_states)

    return hidden_states
114
-
115
def up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames):
    """Apply a resnet followed by a temporal convolution, both checkpointed.

    Args:
        resnet: Layer wrapped in ``'resnet'`` mode and called with ``temb``.
        temp_conv: Layer wrapped in ``'temp'`` mode and called with ``num_frames``.
        hidden_states: Input to the resnet.
        temb: Second argument for the resnet.
        num_frames: Frame count for the temporal convolution.

    Returns:
        The output of the checkpointed temporal convolution step.
    """
    resnet_forward = custom_checkpoint(resnet, mode='resnet')
    after_resnet = g_c(resnet_forward, hidden_states, temb, use_reentrant=False)

    temp_conv_forward = custom_checkpoint(temp_conv, mode='temp')
    return g_c(temp_conv_forward, after_resnet, num_frames, use_reentrant=False)
121
-
122
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=True,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
):
    """Construct the down block selected by ``down_block_type``.

    Supported names are ``"DownBlock3D"`` and ``"CrossAttnDownBlock3D"``.
    The attention-related arguments are only forwarded to the
    cross-attention variant.

    Returns:
        The newly constructed block.

    Raises:
        ValueError: If ``down_block_type`` is ``"CrossAttnDownBlock3D"`` and
            ``cross_attention_dim`` is ``None``, or if the name is unknown.
    """
    # Arguments that both block types receive.
    shared_kwargs = {
        "num_layers": num_layers,
        "in_channels": in_channels,
        "out_channels": out_channels,
        "temb_channels": temb_channels,
        "add_downsample": add_downsample,
        "resnet_eps": resnet_eps,
        "resnet_act_fn": resnet_act_fn,
        "resnet_groups": resnet_groups,
        "downsample_padding": downsample_padding,
        "resnet_time_scale_shift": resnet_time_scale_shift,
    }

    if down_block_type == "DownBlock3D":
        return DownBlock3D(**shared_kwargs)

    if down_block_type == "CrossAttnDownBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
        return CrossAttnDownBlock3D(
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            **shared_kwargs,
        )

    raise ValueError(f"{down_block_type} does not exist.")
176
-
177
-
178
def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=True,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
):
    """Construct the up block selected by ``up_block_type``.

    Supported names are ``"UpBlock3D"`` and ``"CrossAttnUpBlock3D"``.
    The attention-related arguments are only forwarded to the
    cross-attention variant.

    Returns:
        The newly constructed block.

    Raises:
        ValueError: If ``up_block_type`` is ``"CrossAttnUpBlock3D"`` and
            ``cross_attention_dim`` is ``None``, or if the name is unknown.
    """
    # Arguments that both block types receive.
    shared_kwargs = {
        "num_layers": num_layers,
        "in_channels": in_channels,
        "out_channels": out_channels,
        "prev_output_channel": prev_output_channel,
        "temb_channels": temb_channels,
        "add_upsample": add_upsample,
        "resnet_eps": resnet_eps,
        "resnet_act_fn": resnet_act_fn,
        "resnet_groups": resnet_groups,
        "resnet_time_scale_shift": resnet_time_scale_shift,
    }

    if up_block_type == "UpBlock3D":
        return UpBlock3D(**shared_kwargs)

    if up_block_type == "CrossAttnUpBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
        return CrossAttnUpBlock3D(
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            **shared_kwargs,
        )

    raise ValueError(f"{up_block_type} does not exist.")
232
-
233
-
234
class UNetMidBlock3DCrossAttn(nn.Module):
    """Middle block combining resnets, temporal convolutions and attention.

    The block starts with one ``ResnetBlock2D`` / ``TemporalConvLayer`` pair
    and then holds ``num_layers`` groups of ``Transformer2DModel``,
    ``TransformerTemporalModel``, ``ResnetBlock2D`` and ``TemporalConvLayer``.
    Every layer is built with ``in_channels`` as both input and output width.
    ``dual_cross_attention`` is accepted but not used by the constructor.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=True,
        upcast_attention=False,
    ):
        super().__init__()

        self.gradient_checkpointing = False
        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        def new_resnet():
            return ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )

        def new_temp_conv():
            return TemporalConvLayer(in_channels, in_channels, dropout=0.1)

        # there is always at least one resnet
        resnet_stack = [new_resnet()]
        temp_conv_stack = [new_temp_conv()]
        attn_stack = []
        temp_attn_stack = []

        # Layers are created in the same interleaved order on every iteration
        # (attention, temporal attention, resnet, temporal conv).
        for _ in range(num_layers):
            attn_stack.append(
                Transformer2DModel(
                    in_channels // attn_num_head_channels,
                    attn_num_head_channels,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                )
            )
            temp_attn_stack.append(
                TransformerTemporalModel(
                    in_channels // attn_num_head_channels,
                    attn_num_head_channels,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            )
            resnet_stack.append(new_resnet())
            temp_conv_stack.append(new_temp_conv())

        self.resnets = nn.ModuleList(resnet_stack)
        self.temp_convs = nn.ModuleList(temp_conv_stack)
        self.attentions = nn.ModuleList(attn_stack)
        self.temp_attentions = nn.ModuleList(temp_attn_stack)

    def forward(
        self,
        hidden_states,
        temb=None,
        encoder_hidden_states=None,
        attention_mask=None,
        num_frames=1,
        cross_attention_kwargs=None,
    ):
        """Run the leading resnet pair, then each attention/resnet group.

        ``attention_mask`` is accepted but not used. Without gradient
        checkpointing, the leading temporal convolution always runs, while
        the per-group temporal layers only run when ``num_frames > 1``.

        Returns:
            The hidden states after the last group.
        """
        lead_resnet = self.resnets[0]
        lead_temp_conv = self.temp_convs[0]

        if self.gradient_checkpointing:
            hidden_states = up_down_g_c(
                lead_resnet, lead_temp_conv, hidden_states, temb, num_frames
            )
        else:
            hidden_states = lead_resnet(hidden_states, temb)
            hidden_states = lead_temp_conv(hidden_states, num_frames=num_frames)

        groups = zip(
            self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
        )
        for attn, temp_attn, resnet, temp_conv in groups:
            if self.gradient_checkpointing:
                hidden_states = cross_attn_g_c(
                    attn,
                    temp_attn,
                    resnet,
                    temp_conv,
                    hidden_states,
                    encoder_hidden_states,
                    cross_attention_kwargs,
                    temb,
                    num_frames
                )
                continue

            attn_out = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
            )
            hidden_states = attn_out.sample

            if num_frames > 1:
                hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample

            hidden_states = resnet(hidden_states, temb)

            if num_frames > 1:
                hidden_states = temp_conv(hidden_states, num_frames=num_frames)

        return hidden_states
387
-
388
-
389
class CrossAttnDownBlock3D(nn.Module):
    """Down block whose layers each pair a resnet with attention.

    Every one of the ``num_layers`` layers holds a ``ResnetBlock2D``, a
    ``TemporalConvLayer``, a ``Transformer2DModel`` and a
    ``TransformerTemporalModel``. An optional ``Downsample2D`` follows when
    ``add_downsample`` is true. ``dual_cross_attention`` is accepted but not
    used by the constructor.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
    ):
        super().__init__()
        resnet_stack = []
        attn_stack = []
        temp_attn_stack = []
        temp_conv_stack = []

        self.gradient_checkpointing = False
        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for layer_idx in range(num_layers):
            # Only the first resnet sees the block's input width.
            layer_in_channels = in_channels if layer_idx == 0 else out_channels
            resnet_stack.append(
                ResnetBlock2D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_conv_stack.append(
                TemporalConvLayer(out_channels, out_channels, dropout=0.1)
            )
            attn_stack.append(
                Transformer2DModel(
                    out_channels // attn_num_head_channels,
                    attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                )
            )
            temp_attn_stack.append(
                TransformerTemporalModel(
                    out_channels // attn_num_head_channels,
                    attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            )

        self.resnets = nn.ModuleList(resnet_stack)
        self.temp_convs = nn.ModuleList(temp_conv_stack)
        self.attentions = nn.ModuleList(attn_stack)
        self.temp_attentions = nn.ModuleList(temp_attn_stack)

        if not add_downsample:
            self.downsamplers = None
        else:
            downsample = Downsample2D(
                out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
            )
            self.downsamplers = nn.ModuleList([downsample])

    def forward(
        self,
        hidden_states,
        temb=None,
        encoder_hidden_states=None,
        attention_mask=None,
        num_frames=1,
        cross_attention_kwargs=None,
    ):
        """Apply every layer, then the optional downsamplers.

        Returns:
            A ``(hidden_states, output_states)`` pair where ``output_states``
            is a tuple holding the hidden states after each layer and, when
            downsamplers exist, the downsampled result as its last entry.
        """
        # TODO(Patrick, William) - attention mask is not used
        collected = ()

        layers = zip(self.resnets, self.temp_convs, self.attentions, self.temp_attentions)
        for resnet, temp_conv, attn, temp_attn in layers:
            if self.gradient_checkpointing:
                hidden_states = cross_attn_g_c(
                    attn,
                    temp_attn,
                    resnet,
                    temp_conv,
                    hidden_states,
                    encoder_hidden_states,
                    cross_attention_kwargs,
                    temb,
                    num_frames,
                    inverse_temp=True
                )
            else:
                hidden_states = resnet(hidden_states, temb)

                if num_frames > 1:
                    hidden_states = temp_conv(hidden_states, num_frames=num_frames)

                attn_out = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
                hidden_states = attn_out.sample

                if num_frames > 1:
                    hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample

            collected += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            collected += (hidden_states,)

        return hidden_states, collected
537
-
538
-
539
class DownBlock3D(nn.Module):
    """Down block made of resnet / temporal-convolution pairs.

    Holds ``num_layers`` pairs of ``ResnetBlock2D`` and ``TemporalConvLayer``
    plus an optional ``Downsample2D`` when ``add_downsample`` is true.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()
        resnet_stack = []
        temp_conv_stack = []

        self.gradient_checkpointing = False
        for layer_idx in range(num_layers):
            # Only the first resnet sees the block's input width.
            layer_in_channels = in_channels if layer_idx == 0 else out_channels
            resnet_stack.append(
                ResnetBlock2D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_conv_stack.append(
                TemporalConvLayer(out_channels, out_channels, dropout=0.1)
            )

        self.resnets = nn.ModuleList(resnet_stack)
        self.temp_convs = nn.ModuleList(temp_conv_stack)

        if not add_downsample:
            self.downsamplers = None
        else:
            downsample = Downsample2D(
                out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
            )
            self.downsamplers = nn.ModuleList([downsample])

    def forward(self, hidden_states, temb=None, num_frames=1):
        """Apply every resnet/temporal-conv pair, then the optional downsamplers.

        Without gradient checkpointing the temporal convolution only runs
        when ``num_frames > 1``.

        Returns:
            A ``(hidden_states, output_states)`` pair where ``output_states``
            is a tuple holding the hidden states after each layer and, when
            downsamplers exist, the downsampled result as its last entry.
        """
        collected = ()

        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            if self.gradient_checkpointing:
                hidden_states = up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames)
            else:
                hidden_states = resnet(hidden_states, temb)

                if num_frames > 1:
                    hidden_states = temp_conv(hidden_states, num_frames=num_frames)

            collected += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            collected += (hidden_states,)

        return hidden_states, collected
620
-
621
-
622
class CrossAttnUpBlock3D(nn.Module):
    """Up block whose layers each pair a resnet with attention.

    Every one of the ``num_layers`` layers holds a ``ResnetBlock2D``, a
    ``TemporalConvLayer``, a ``Transformer2DModel`` and a
    ``TransformerTemporalModel``. An optional ``Upsample2D`` follows when
    ``add_upsample`` is true. ``dual_cross_attention`` is accepted but not
    used by the constructor.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
    ):
        super().__init__()
        resnet_stack = []
        temp_conv_stack = []
        attn_stack = []
        temp_attn_stack = []

        self.gradient_checkpointing = False
        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        last_idx = num_layers - 1
        for layer_idx in range(num_layers):
            # The resnet input is the running hidden state concatenated with a
            # skip tensor, so its width is the sum of the two.
            skip_channels = in_channels if layer_idx == last_idx else out_channels
            running_channels = prev_output_channel if layer_idx == 0 else out_channels

            resnet_stack.append(
                ResnetBlock2D(
                    in_channels=running_channels + skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_conv_stack.append(
                TemporalConvLayer(out_channels, out_channels, dropout=0.1)
            )
            attn_stack.append(
                Transformer2DModel(
                    out_channels // attn_num_head_channels,
                    attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                )
            )
            temp_attn_stack.append(
                TransformerTemporalModel(
                    out_channels // attn_num_head_channels,
                    attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            )

        self.resnets = nn.ModuleList(resnet_stack)
        self.temp_convs = nn.ModuleList(temp_conv_stack)
        self.attentions = nn.ModuleList(attn_stack)
        self.temp_attentions = nn.ModuleList(temp_attn_stack)

        if not add_upsample:
            self.upsamplers = None
        else:
            upsample = Upsample2D(out_channels, use_conv=True, out_channels=out_channels)
            self.upsamplers = nn.ModuleList([upsample])

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
        attention_mask=None,
        num_frames=1,
        cross_attention_kwargs=None,
    ):
        """Consume one skip tensor per layer, then apply the optional upsamplers.

        For each layer the last entry of ``res_hidden_states_tuple`` is
        removed and concatenated to ``hidden_states`` along dim 1 before the
        layer runs.

        Returns:
            The hidden states after the last layer (and upsamplers, if any).
        """
        # TODO(Patrick, William) - attention mask is not used
        layers = zip(self.resnets, self.temp_convs, self.attentions, self.temp_attentions)
        for resnet, temp_conv, attn, temp_attn in layers:
            # pop res hidden states
            res_hidden_states_tuple, skip_states = (
                res_hidden_states_tuple[:-1],
                res_hidden_states_tuple[-1],
            )
            hidden_states = torch.cat([hidden_states, skip_states], dim=1)

            if self.gradient_checkpointing:
                hidden_states = cross_attn_g_c(
                    attn,
                    temp_attn,
                    resnet,
                    temp_conv,
                    hidden_states,
                    encoder_hidden_states,
                    cross_attention_kwargs,
                    temb,
                    num_frames,
                    inverse_temp=True
                )
                continue

            hidden_states = resnet(hidden_states, temb)

            if num_frames > 1:
                hidden_states = temp_conv(hidden_states, num_frames=num_frames)

            attn_out = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
            )
            hidden_states = attn_out.sample

            if num_frames > 1:
                hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
766
-
767
-
768
class UpBlock3D(nn.Module):
    """Up block made of resnet / temporal-convolution pairs.

    Holds ``num_layers`` pairs of ``ResnetBlock2D`` and ``TemporalConvLayer``
    plus an optional ``Upsample2D`` when ``add_upsample`` is true.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super().__init__()
        resnet_stack = []
        temp_conv_stack = []
        self.gradient_checkpointing = False

        last_idx = num_layers - 1
        for layer_idx in range(num_layers):
            # The resnet input is the running hidden state concatenated with a
            # skip tensor, so its width is the sum of the two.
            skip_channels = in_channels if layer_idx == last_idx else out_channels
            running_channels = prev_output_channel if layer_idx == 0 else out_channels

            resnet_stack.append(
                ResnetBlock2D(
                    in_channels=running_channels + skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_conv_stack.append(
                TemporalConvLayer(out_channels, out_channels, dropout=0.1)
            )

        self.resnets = nn.ModuleList(resnet_stack)
        self.temp_convs = nn.ModuleList(temp_conv_stack)

        if not add_upsample:
            self.upsamplers = None
        else:
            upsample = Upsample2D(out_channels, use_conv=True, out_channels=out_channels)
            self.upsamplers = nn.ModuleList([upsample])

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
        """Consume one skip tensor per layer, then apply the optional upsamplers.

        For each layer the last entry of ``res_hidden_states_tuple`` is
        removed and concatenated to ``hidden_states`` along dim 1 before the
        layer runs. Without gradient checkpointing the temporal convolution
        only runs when ``num_frames > 1``.

        Returns:
            The hidden states after the last layer (and upsamplers, if any).
        """
        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            # pop res hidden states
            res_hidden_states_tuple, skip_states = (
                res_hidden_states_tuple[:-1],
                res_hidden_states_tuple[-1],
            )
            hidden_states = torch.cat([hidden_states, skip_states], dim=1)

            if self.gradient_checkpointing:
                hidden_states = up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames)
                continue

            hidden_states = resnet(hidden_states, temb)

            if num_frames > 1:
                hidden_states = temp_conv(hidden_states, num_frames=num_frames)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states