heboya8 committed on
Commit
07f93c8
·
verified ·
1 Parent(s): c9e511a

Upload unet_3d_condition.py

Browse files
Files changed (1) hide show
  1. unet/unet_3d_condition.py +503 -0
unet/unet_3d_condition.py ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
2
+ # Copyright 2023 The ModelScope Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin

# The import path of TransformerTemporalModel has moved across diffusers releases.
# `diffusers.models.temporal` (used previously here) does not exist in any release;
# try the modern location first, then fall back to the legacy one.
try:
    from diffusers.models.transformers.transformer_temporal import TransformerTemporalModel
except ImportError:
    from diffusers.models.transformer_temporal import TransformerTemporalModel

from .unet_3d_blocks import (
    CrossAttnDownBlock3D,
    CrossAttnUpBlock3D,
    DownBlock3D,
    UNetMidBlock3DCrossAttn,
    UpBlock3D,
    get_down_block,
    get_up_block,
    transformer_g_c,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
43
+
44
+
45
@dataclass
class UNet3DConditionOutput(BaseOutput):
    """
    Output of [`UNet3DConditionModel.forward`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    # The denoised sample produced by the final layer of the UNet.
    sample: torch.FloatTensor
54
+
55
+
56
class UNet3DConditionModel(ModelMixin, ConfigMixin):
    r"""
    UNet3DConditionModel is a conditional 3D UNet model that takes in a noisy sample, conditional state, and a
    timestep and returns a sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all the models (such as downloading or saving, etc.)

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
            The tuple of downsample blocks to use.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
            The tuple of upsample blocks to use.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, it will skip the normalization and activation layers in post-processing
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
        attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 64): The dimension of the attention heads.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1024,
        attention_head_dim: Union[int, Tuple[int]] = 64,
    ):
        super().__init__()

        self.sample_size = sample_size
        self.gradient_checkpointing = False

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        # input: frames are folded into the batch dimension, so a 2D conv is used per-frame
        conv_in_kernel = 3
        conv_out_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        time_embed_dim = block_out_channels[0] * 4
        self.time_proj = Timesteps(block_out_channels[0], True, 0)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
        )

        # temporal transformer applied right after conv_in (only when num_frames > 1)
        self.transformer_in = TransformerTemporalModel(
            num_attention_heads=8,
            attention_head_dim=attention_head_dim,
            in_channels=block_out_channels[0],
            num_layers=1,
        )

        # class embedding
        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=False,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock3DCrossAttn(
            in_channels=block_out_channels[-1],
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attention_head_dim[-1],
            resnet_groups=norm_num_groups,
            dual_cross_attention=False,
        )

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=reversed_attention_head_dim[i],
                dual_cross_attention=False,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        if norm_num_groups is not None:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )
            self.conv_act = nn.SiLU()
        else:
            self.conv_norm_out = None
            self.conv_act = None

        conv_out_padding = (conv_out_kernel - 1) // 2
        self.conv_out = nn.Conv2d(
            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        )
255
+
256
+ def set_attention_slice(self, slice_size):
257
+ r"""
258
+ Enable sliced attention computation.
259
+
260
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
261
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
262
+
263
+ Args:
264
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
265
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
266
+ `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
267
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
268
+ must be a multiple of `slice_size`.
269
+ """
270
+ sliceable_head_dims = []
271
+
272
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
273
+ if hasattr(module, "set_attention_slice"):
274
+ sliceable_head_dims.append(module.sliceable_head_dim)
275
+
276
+ for child in module.children():
277
+ fn_recursive_retrieve_slicable_dims(child)
278
+
279
+ # retrieve number of attention layers
280
+ for module in self.children():
281
+ fn_recursive_retrieve_slicable_dims(module)
282
+
283
+ num_slicable_layers = len(sliceable_head_dims)
284
+
285
+ if slice_size == "auto":
286
+ # half the attention head size is usually a good trade-off between
287
+ # speed and memory
288
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
289
+ elif slice_size == "max":
290
+ # make smallest slice possible
291
+ slice_size = num_slicable_layers * [1]
292
+
293
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
294
+
295
+ if len(slice_size) != len(sliceable_head_dims):
296
+ raise ValueError(
297
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
298
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
299
+ )
300
+
301
+ for i in range(len(slice_size)):
302
+ size = slice_size[i]
303
+ dim = sliceable_head_dims[i]
304
+ if size is not None and size > dim:
305
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
306
+
307
+ # Recursively walk through all the children.
308
+ # Any children which exposes the set_attention_slice method
309
+ # gets the message
310
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
311
+ if hasattr(module, "set_attention_slice"):
312
+ module.set_attention_slice(slice_size.pop())
313
+
314
+ for child in module.children():
315
+ fn_recursive_set_attention_slice(child, slice_size)
316
+
317
+ reversed_slice_size = list(reversed(slice_size))
318
+ for module in self.children():
319
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
320
+
321
+ def _set_gradient_checkpointing(self, value=False):
322
+ self.gradient_checkpointing = value
323
+ self.mid_block.gradient_checkpointing = value
324
+ for module in self.down_blocks + self.up_blocks:
325
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
326
+ module.gradient_checkpointing = value
327
+
328
+ def forward(
329
+ self,
330
+ sample: torch.FloatTensor,
331
+ timestep: Union[torch.Tensor, float, int],
332
+ encoder_hidden_states: torch.Tensor,
333
+ class_labels: Optional[torch.Tensor] = None,
334
+ timestep_cond: Optional[torch.Tensor] = None,
335
+ attention_mask: Optional[torch.Tensor] = None,
336
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
337
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
338
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
339
+ return_dict: bool = True,
340
+ ) -> Union[UNet3DConditionOutput, Tuple]:
341
+ r"""
342
+ Args:
343
+ sample (`torch.FloatTensor`): (batch, num_frames, channel, height, width) noisy inputs tensor
344
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
345
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
346
+ return_dict (`bool`, *optional*, defaults to `True`):
347
+ Whether or not to return a [`models.unet_2d_condition.UNet3DConditionOutput`] instead of a plain tuple.
348
+ cross_attention_kwargs (`dict`, *optional*):
349
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
350
+ `self.processor` in
351
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
352
+
353
+ Returns:
354
+ [`~models.unet_2d_condition.UNet3DConditionOutput`] or `tuple`:
355
+ [`~models.unet_2d_condition.UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
356
+ returning a tuple, the first element is the sample tensor.
357
+ """
358
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
359
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
360
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
361
+ # on the fly if necessary.
362
+ default_overall_up_factor = 2**self.num_upsamplers
363
+
364
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
365
+ forward_upsample_size = False
366
+ upsample_size = None
367
+
368
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
369
+ logger.info("Forward upsample size to force interpolation output size.")
370
+ forward_upsample_size = True
371
+
372
+ # prepare attention_mask
373
+ if attention_mask is not None:
374
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
375
+ attention_mask = attention_mask.unsqueeze(1)
376
+
377
+ # 1. time
378
+ timesteps = timestep
379
+ if not torch.is_tensor(timesteps):
380
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
381
+ # This would be a good case for the `match` statement (Python 3.10+)
382
+ is_mps = sample.device.type == "mps"
383
+ if isinstance(timestep, float):
384
+ dtype = torch.float32 if is_mps else torch.float64
385
+ else:
386
+ dtype = torch.int32 if is_mps else torch.int64
387
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
388
+ elif len(timesteps.shape) == 0:
389
+ timesteps = timesteps[None].to(sample.device)
390
+
391
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
392
+ num_frames = sample.shape[2]
393
+ timesteps = timesteps.expand(sample.shape[0])
394
+
395
+ t_emb = self.time_proj(timesteps)
396
+
397
+ # timesteps does not contain any weights and will always return f32 tensors
398
+ # but time_embedding might actually be running in fp16. so we need to cast here.
399
+ # there might be better ways to encapsulate this.
400
+ t_emb = t_emb.to(dtype=self.dtype)
401
+
402
+ emb = self.time_embedding(t_emb, timestep_cond)
403
+ emb = emb.repeat_interleave(repeats=num_frames, dim=0)
404
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
405
+
406
+ # 2. pre-process
407
+ sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
408
+ sample = self.conv_in(sample)
409
+
410
+ if num_frames > 1:
411
+ if self.gradient_checkpointing:
412
+ sample = transformer_g_c(self.transformer_in, sample, num_frames)
413
+ else:
414
+ sample = self.transformer_in(sample, num_frames=num_frames).sample
415
+
416
+ # 3. down
417
+ down_block_res_samples = (sample,)
418
+ for downsample_block in self.down_blocks:
419
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
420
+ sample, res_samples = downsample_block(
421
+ hidden_states=sample,
422
+ temb=emb,
423
+ encoder_hidden_states=encoder_hidden_states,
424
+ attention_mask=attention_mask,
425
+ num_frames=num_frames,
426
+ cross_attention_kwargs=cross_attention_kwargs,
427
+ )
428
+ else:
429
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
430
+
431
+ down_block_res_samples += res_samples
432
+
433
+ if down_block_additional_residuals is not None:
434
+ new_down_block_res_samples = ()
435
+
436
+ for down_block_res_sample, down_block_additional_residual in zip(
437
+ down_block_res_samples, down_block_additional_residuals
438
+ ):
439
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
440
+ new_down_block_res_samples += (down_block_res_sample,)
441
+
442
+ down_block_res_samples = new_down_block_res_samples
443
+
444
+ # 4. mid
445
+ if self.mid_block is not None:
446
+ sample = self.mid_block(
447
+ sample,
448
+ emb,
449
+ encoder_hidden_states=encoder_hidden_states,
450
+ attention_mask=attention_mask,
451
+ num_frames=num_frames,
452
+ cross_attention_kwargs=cross_attention_kwargs,
453
+ )
454
+
455
+ if mid_block_additional_residual is not None:
456
+ sample = sample + mid_block_additional_residual
457
+
458
+ # 5. up
459
+ for i, upsample_block in enumerate(self.up_blocks):
460
+ is_final_block = i == len(self.up_blocks) - 1
461
+
462
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
463
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
464
+
465
+ # if we have not reached the final block and need to forward the
466
+ # upsample size, we do it here
467
+ if not is_final_block and forward_upsample_size:
468
+ upsample_size = down_block_res_samples[-1].shape[2:]
469
+
470
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
471
+ sample = upsample_block(
472
+ hidden_states=sample,
473
+ temb=emb,
474
+ res_hidden_states_tuple=res_samples,
475
+ encoder_hidden_states=encoder_hidden_states,
476
+ upsample_size=upsample_size,
477
+ attention_mask=attention_mask,
478
+ num_frames=num_frames,
479
+ cross_attention_kwargs=cross_attention_kwargs,
480
+ )
481
+ else:
482
+ sample = upsample_block(
483
+ hidden_states=sample,
484
+ temb=emb,
485
+ res_hidden_states_tuple=res_samples,
486
+ upsample_size=upsample_size,
487
+ num_frames=num_frames,
488
+ )
489
+
490
+ # 6. post-process
491
+ if self.conv_norm_out:
492
+ sample = self.conv_norm_out(sample)
493
+ sample = self.conv_act(sample)
494
+
495
+ sample = self.conv_out(sample)
496
+
497
+ # reshape to (batch, channel, framerate, width, height)
498
+ sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)
499
+
500
+ if not return_dict:
501
+ return (sample,)
502
+
503
+ return UNet3DConditionOutput(sample=sample)