xixircc committed on
Commit
b5f47ea
·
verified ·
1 Parent(s): 644034f

Delete folder models with huggingface_hub

Browse files
Files changed (2) hide show
  1. models/unet_3d.py +0 -727
  2. models/unet_3d_blocks.py +0 -1121
models/unet_3d.py DELETED
@@ -1,727 +0,0 @@
1
- # *************************************************************************
2
- # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
3
- # SPDX-License-Identifier: Apache-2.0
4
- #
5
- # This file has been modified by ByteDance Ltd. and/or its affiliates.
6
- #
7
- # Original file was released under Aniportrait, with the full license text
8
- # available at https://github.com/Zejun-Yang/AniPortrait/blob/main/LICENSE.
9
- #
10
- # This modified file is released under the same license.
11
- # *************************************************************************
12
- from collections import OrderedDict
13
- from dataclasses import dataclass
14
- import pdb
15
- from os import PathLike
16
- from pathlib import Path
17
- from typing import Dict, List, Optional, Tuple, Union
18
-
19
- import torch
20
- import torch.nn as nn
21
- import torch.utils.checkpoint
22
- import torch.nn.functional as F
23
- from diffusers.configuration_utils import ConfigMixin, register_to_config
24
- from diffusers.models.attention_processor import AttentionProcessor
25
- from diffusers.models.embeddings import TimestepEmbedding, Timesteps
26
- from diffusers.models.modeling_utils import ModelMixin
27
- from diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
28
- from safetensors.torch import load_file
29
-
30
- from .resnet import InflatedConv3d, InflatedGroupNorm
31
- from .unet_3d_blocks import UNetMidBlock3DCrossAttn, get_down_block, get_up_block
32
-
33
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
34
-
35
-
36
@dataclass
class UNet3DConditionOutput(BaseOutput):
    """Output container returned by `UNet3DConditionModel.forward` when `return_dict` is True."""

    # Tensor produced by the model's final output convolution.
    sample: torch.FloatTensor
39
-
40
-
41
class UNet3DConditionModel(ModelMixin, ConfigMixin):
    """Conditional UNet assembled from 3D down/mid/up blocks.

    The network consists of an inflated input convolution, a timestep embedding
    (plus an optional class embedding), a stack of down blocks built with
    `get_down_block`, a `UNetMidBlock3DCrossAttn` mid block, a stack of up
    blocks built with `get_up_block`, and a norm/activation/convolution output
    head. Motion and temporal modules can be enabled per block through the
    `use_motion_module` / `use_temporal_module` family of arguments.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        mid_block_type: str = "UNetMidBlock3DCrossAttn",
        up_block_types: Tuple[str] = (
            "UpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        use_inflated_groupnorm=False,
        # Additional
        use_motion_module=False,
        use_temporal_module=False,
        motion_module_resolutions=(1, 2, 4, 8),
        motion_module_mid_block=False,
        motion_module_decoder_only=False,
        motion_module_type=None,
        temporal_module_type=None,
        # NOTE(review): the two `{}` defaults are shared mutable objects. They are
        # kept as-is because `register_to_config` records the defaults in the model
        # config; nothing in this class mutates them.
        motion_module_kwargs={},
        temporal_module_kwargs={},
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
    ):
        """Build all sub-modules of the UNet.

        Args:
            sample_size: Stored on the instance as `self.sample_size`; not otherwise
                used during construction.
            in_channels: Input channels of `conv_in`.
            out_channels: Output channels of `conv_out`.
            center_input_sample: Read in `forward` via `self.config`; when true the
                input is mapped with `2 * sample - 1.0`.
            flip_sin_to_cos: Forwarded to `Timesteps`.
            freq_shift: Forwarded to `Timesteps`.
            down_block_types: Block type name for each down block.
            mid_block_type: Only `"UNetMidBlock3DCrossAttn"` is accepted.
            up_block_types: Block type name for each up block.
            only_cross_attention: A single bool is broadcast to one value per down
                block; the reversed sequence is used for the up blocks.
            block_out_channels: Output channel count of each down block. The time
                embedding width is `block_out_channels[0] * 4`.
            layers_per_block: Layers per down block; up blocks get one more.
            downsample_padding: Forwarded to the down blocks.
            mid_block_scale_factor: Forwarded to the mid block as
                `output_scale_factor`.
            act_fn: Forwarded to the blocks as `resnet_act_fn`.
            norm_num_groups: Group count for the blocks and the output norm.
            norm_eps: Epsilon for the blocks and the output norm.
            cross_attention_dim: Forwarded to every block.
            attention_head_dim: A single int is broadcast to one value per down
                block; forwarded to the blocks as `attn_num_head_channels`.
            dual_cross_attention: Forwarded to every block.
            use_linear_projection: Forwarded to every block.
            class_embed_type: `None`, `"timestep"` or `"identity"`; selects the
                class-embedding module.
            num_class_embeds: When `class_embed_type` is `None` and this is set, an
                `nn.Embedding` with this many entries is used.
            upcast_attention: Forwarded to every block.
            resnet_time_scale_shift: Forwarded to every block.
            use_inflated_groupnorm: Selects `InflatedGroupNorm` over `nn.GroupNorm`
                for the output norm, and is forwarded to every block.
            use_motion_module: Master switch for motion modules.
            use_temporal_module: Master switch for temporal modules.
            motion_module_resolutions: Resolution factors (powers of two) at which
                motion/temporal modules are enabled.
            motion_module_mid_block: Enables the motion/temporal module in the mid
                block (both switches are gated by this one flag).
            motion_module_decoder_only: When true, the down blocks get no
                motion/temporal module.
            motion_module_type: Forwarded to every block.
            temporal_module_type: Forwarded to every block.
            motion_module_kwargs: Forwarded to every block.
            temporal_module_kwargs: Forwarded to every block.
            unet_use_cross_frame_attention: Forwarded to every block.
            unet_use_temporal_attention: Forwarded to every block.

        Raises:
            ValueError: If `mid_block_type` is not `"UNetMidBlock3DCrossAttn"`.
        """
        super().__init__()

        self.sample_size = sample_size
        time_embed_dim = block_out_channels[0] * 4

        # input
        self.conv_in = InflatedConv3d(
            in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)
        )

        # time
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        else:
            self.class_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # Broadcast scalar settings to one entry per down block.
        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            # Resolution factor of this block; matched against
            # `motion_module_resolutions` below.
            res = 2**i
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,
                use_inflated_groupnorm=use_inflated_groupnorm,
                use_motion_module=use_motion_module
                and (res in motion_module_resolutions)
                and (not motion_module_decoder_only),
                use_temporal_module=use_temporal_module
                and (res in motion_module_resolutions)
                and (not motion_module_decoder_only),
                motion_module_type=motion_module_type,
                temporal_module_type=temporal_module_type,
                motion_module_kwargs=motion_module_kwargs,
                temporal_module_kwargs=temporal_module_kwargs,
            )
            self.down_blocks.append(down_block)

        # mid
        if mid_block_type == "UNetMidBlock3DCrossAttn":
            self.mid_block = UNetMidBlock3DCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,
                use_inflated_groupnorm=use_inflated_groupnorm,
                use_motion_module=use_motion_module and motion_module_mid_block,
                use_temporal_module=use_temporal_module and motion_module_mid_block,
                motion_module_type=motion_module_type,
                temporal_module_type=temporal_module_type,
                motion_module_kwargs=motion_module_kwargs,
                temporal_module_kwargs=temporal_module_kwargs,
            )
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")

        # count how many layers upsample the videos
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        only_cross_attention = list(reversed(only_cross_attention))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            # NOTE(review): the constant 3 mirrors the down path only when there
            # are exactly four blocks; confirm before using another depth.
            res = 2 ** (3 - i)
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[
                min(i + 1, len(block_out_channels) - 1)
            ]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=reversed_attention_head_dim[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,
                use_inflated_groupnorm=use_inflated_groupnorm,
                use_motion_module=use_motion_module
                and (res in motion_module_resolutions),
                use_temporal_module=use_temporal_module
                and (res in motion_module_resolutions),
                motion_module_type=motion_module_type,
                temporal_module_type=temporal_module_type,
                motion_module_kwargs=motion_module_kwargs,
                temporal_module_kwargs=temporal_module_kwargs,
            )
            self.up_blocks.append(up_block)

        # out
        if use_inflated_groupnorm:
            self.conv_norm_out = InflatedGroupNorm(
                num_channels=block_out_channels[0],
                num_groups=norm_num_groups,
                eps=norm_eps,
            )
        else:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0],
                num_groups=norm_num_groups,
                eps=norm_eps,
            )
        self.conv_act = nn.SiLU()
        self.conv_out = InflatedConv3d(
            block_out_channels[0], out_channels, kernel_size=3, padding=1
        )

    @property
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by weight name. Sub-modules whose name contains `"temporal_transformer"` are not traversed.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(
            name: str,
            module: torch.nn.Module,
            processors: Dict[str, AttentionProcessor],
        ):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                if "temporal_transformer" not in sub_name:
                    fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            if "temporal_transformer" not in name:
                fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.

        Raises:
            ValueError: If a list is given whose length differs from the number of sliceable layers, or if any
                slice size exceeds the corresponding layer's sliceable head dim.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_slicable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_slicable_dims(module)

        num_slicable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_slicable_layers * [1]

        # A scalar is broadcast to one entry per sliceable layer.
        slice_size = (
            num_slicable_layers * [slice_size]
            if not isinstance(slice_size, list)
            else slice_size
        )

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        # Lengths are equal at this point, so zip visits every pair.
        for size, dim in zip(slice_size, sliceable_head_dims):
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(
            module: torch.nn.Module, slice_size: List[int]
        ):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        # Reversed so that `pop()` hands out sizes in the same order in which
        # the head dims were collected above.
        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def set_use_cross_frame_attention(self, value):
        """Call `set_use_cross_frame_attention(value)` on every descendant module that defines it."""

        def fn_recursive_set_use_cf_att(module: torch.nn.Module, value):
            if hasattr(module, "set_use_cross_frame_attention"):
                module.set_use_cross_frame_attention(value)

            for child in module.children():
                fn_recursive_set_use_cf_att(child, value)

        for module in self.children():
            fn_recursive_set_use_cf_att(module, value)

    def _set_gradient_checkpointing(self, module, value=False):
        """Set `module.gradient_checkpointing` to `value` if the module has that attribute."""
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(
        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
    ):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors. Entries are
                popped from the dict as they are assigned.

        Raises:
            ValueError: If a dict is passed whose size differs from the number of attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                if "temporal_transformer" not in sub_name:
                    fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            if "temporal_transformer" not in name:
                fn_recursive_attn_processor(name, module, processor)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        pose_cond_fea = None,
        attention_mask: Optional[torch.Tensor] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        skip_mm: bool = False,
    ) -> Union[UNet3DConditionOutput, Tuple]:
        r"""
        Run the UNet on a noisy sample.

        Args:
            sample (`torch.FloatTensor`): noisy inputs tensor. Its first dimension is the batch and its last two
                dimensions are checked against the overall upsampling factor.
                NOTE(review): presumably (batch, channel, frames, height, width) -- confirm against `InflatedConv3d`.
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps. Scalars are converted to a
                one-element tensor and expanded to the batch size.
            encoder_hidden_states (`torch.FloatTensor`): encoder hidden states, forwarded to every block.
            class_labels (`torch.Tensor`, *optional*): required when the model has a class embedding.
            pose_cond_fea (*optional*): indexable collection of features. Entry 0 is added after `conv_in`; entry
                `k` (starting at 1) is added to the output of the k-th down block.
            attention_mask (`torch.Tensor`, *optional*): converted to an additive bias via
                `(1 - mask) * -10000.0` and unsqueezed at dim 1 before being passed to the attention blocks.
            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): added element-wise to the
                down-block residuals.
            mid_block_additional_residual (`torch.Tensor`, *optional*): added to the mid-block output.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.
            skip_mm (`bool`, *optional*, defaults to `False`): forwarded unchanged to every down, mid and up block.

        Returns:
            [`UNet3DConditionOutput`] or `tuple`:
            [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: If the model has a class embedding but `class_labels` is `None`.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)
        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError(
                    "class_labels should be provided when num_class_embeds > 0"
                )

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # pre-process
        sample = self.conv_in(sample)
        if pose_cond_fea is not None:
            sample = sample + pose_cond_fea[0]

        # down
        down_block_res_samples = (sample,)
        # Index of the next pose feature; entry 0 was consumed above.
        block_count = 1
        for downsample_block in self.down_blocks:
            if (
                hasattr(downsample_block, "has_cross_attention")
                and downsample_block.has_cross_attention
            ):
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    skip_mm=skip_mm,
                )
            else:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    skip_mm=skip_mm,
                )
            if pose_cond_fea is not None:
                sample = sample + pose_cond_fea[block_count]
            block_count += 1
            down_block_res_samples += res_samples

        if down_block_additional_residuals is not None:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = (
                    down_block_res_sample + down_block_additional_residual
                )
                new_down_block_res_samples += (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # mid
        sample = self.mid_block(
            sample,
            emb,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            skip_mm=skip_mm,
        )

        if mid_block_additional_residual is not None:
            sample = sample + mid_block_additional_residual

        # up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            # Each up block consumes one residual per resnet, taken from the
            # end of the collected down-path residuals.
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[
                : -len(upsample_block.resnets)
            ]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if (
                hasattr(upsample_block, "has_cross_attention")
                and upsample_block.has_cross_attention
            ):
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    skip_mm=skip_mm,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    encoder_hidden_states=encoder_hidden_states,
                    skip_mm=skip_mm,
                )

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet3DConditionOutput(sample=sample)

    @classmethod
    def from_pretrained_2d(
        cls,
        pretrained_model_path: PathLike,
        motion_module_path: PathLike,
        subfolder=None,
        unet_additional_kwargs=None,
        mm_zero_proj_out=False,
    ):
        """Build a model from a pretrained checkpoint directory plus optional motion-module weights.

        The configuration is read from `config.json` in the model directory, its
        block types are overridden with the 3D block names, and the model is
        instantiated from that config. Base weights are then loaded
        (safetensors preferred, otherwise the torch weights file), merged with
        the motion-module weights, and applied with `strict=False`.

        Args:
            pretrained_model_path: Directory holding `config.json` and the weights.
            motion_module_path: Motion-module weights file (`.pth`, `.pt`, `.ckpt`
                or `.safetensors`). If it does not exist or is not a file, no
                motion-module weights are merged.
            subfolder: Optional sub-directory appended to `pretrained_model_path`.
            unet_additional_kwargs: Extra keyword arguments passed to
                `from_config`; `None` is treated as no extra arguments.
            mm_zero_proj_out: When true, motion-module entries whose key contains
                `"proj_out"` are left out of the merge, so those layers keep the
                values the freshly constructed model gave them.

        Returns:
            The constructed model with the merged weights loaded.

        Raises:
            RuntimeError: If `config.json` is missing, or the motion-module file
                has an unsupported suffix.
            FileNotFoundError: If no base weights file is found.
        """
        pretrained_model_path = Path(pretrained_model_path)
        motion_module_path = Path(motion_module_path)
        if subfolder is not None:
            pretrained_model_path = pretrained_model_path.joinpath(subfolder)
        logger.info(
            f"loaded temporal unet's pretrained weights from {pretrained_model_path} ..."
        )

        config_file = pretrained_model_path / "config.json"
        if not (config_file.exists() and config_file.is_file()):
            raise RuntimeError(f"{config_file} does not exist or is not a file")

        unet_config = cls.load_config(config_file)
        unet_config["_class_name"] = cls.__name__
        unet_config["down_block_types"] = [
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ]
        unet_config["up_block_types"] = [
            "UpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
        ]
        unet_config["mid_block_type"] = "UNetMidBlock3DCrossAttn"

        # `unet_additional_kwargs` defaults to None, which cannot be unpacked
        # with `**`; fall back to an empty mapping.
        model = cls.from_config(unet_config, **(unet_additional_kwargs or {}))

        # load the vanilla weights
        if pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME).exists():
            logger.debug(
                f"loading safeTensors weights from {pretrained_model_path} ..."
            )
            state_dict = load_file(
                pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME), device="cpu"
            )

        elif pretrained_model_path.joinpath(WEIGHTS_NAME).exists():
            logger.debug(f"loading weights from {pretrained_model_path} ...")
            state_dict = torch.load(
                pretrained_model_path.joinpath(WEIGHTS_NAME),
                map_location="cpu",
                weights_only=True,
            )
        else:
            raise FileNotFoundError(f"no weights file found in {pretrained_model_path}")

        # load the motion module weights
        if motion_module_path.exists() and motion_module_path.is_file():
            if motion_module_path.suffix.lower() in [".pth", ".pt", ".ckpt"]:
                logger.info(f"Load motion module params from {motion_module_path}")
                motion_state_dict = torch.load(
                    motion_module_path, map_location="cpu", weights_only=True
                )
            elif motion_module_path.suffix.lower() == ".safetensors":
                motion_state_dict = load_file(motion_module_path, device="cpu")
            else:
                raise RuntimeError(
                    f"unknown file format for motion module weights: {motion_module_path.suffix}"
                )

            # Rename the checkpoint's `motion_modules.` prefix to
            # `temporal_modules.` and drop every `pos_encoder` entry.
            motion_state_dict = {
                k.replace('motion_modules.', 'temporal_modules.'): v
                for k, v in motion_state_dict.items()
                if "pos_encoder" not in k
            }

            if mm_zero_proj_out:
                logger.info("Zero initialize proj_out layers in motion module...")
                new_motion_state_dict = OrderedDict()
                for k in motion_state_dict:
                    if "proj_out" in k:
                        continue
                    new_motion_state_dict[k] = motion_state_dict[k]
                motion_state_dict = new_motion_state_dict

            # merge the state dicts
            state_dict.update(motion_state_dict)

        # load the weights into the model
        m, u = model.load_state_dict(state_dict, strict=False)
        logger.debug(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")

        # Parameter counts by name substring, for the summary log line only.
        params = [
            p.numel() if "temporal_modules" in n else 0
            for n, p in model.named_parameters()
        ]
        mm_params = [
            p.numel() if "motion_modules" in n else 0
            for n, p in model.named_parameters()
        ]
        logger.info(
            f"Loaded {sum(mm_params) / 1e6}M-parameter motion module, Loaded {sum(params) / 1e6}M-parameter temporal module"
        )

        return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/unet_3d_blocks.py DELETED
@@ -1,1121 +0,0 @@
1
- # *************************************************************************
2
- # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
3
- # SPDX-License-Identifier: Apache-2.0
4
- #
5
- # This file has been modified by ByteDance Ltd. and/or its affiliates.
6
- #
7
- # Original file was released under Aniportrait, with the full license text
8
- # available at https://github.com/Zejun-Yang/AniPortrait/blob/main/LICENSE.
9
- #
10
- # This modified file is released under the same license.
11
- # *************************************************************************
12
- import pdb
13
- from typing import Dict, Optional
14
- import torch
15
- from torch import nn
16
-
17
- from src.models.motion_module import get_motion_module
18
-
19
- # from .motion_module import get_motion_module
20
- from src.models.resnet import Downsample3D, ResnetBlock3D, Upsample3D
21
- from .transformer_3d import Transformer3DModel
22
-
23
-
24
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    unet_use_cross_frame_attention=None,
    unet_use_temporal_attention=None,
    use_inflated_groupnorm=None,
    use_motion_module=None,
    motion_module_type=None,
    motion_module_kwargs=None,
    use_temporal_module=None,
    temporal_module_type=None,
    temporal_module_kwargs=None,
):
    """Instantiate the down block selected by ``down_block_type``.

    A leading ``"UNetRes"`` prefix is dropped from the type name before it is
    matched. The recognised names are ``"DownBlock3D"`` and
    ``"CrossAttnDownBlock3D"``; the attention-related arguments are only
    forwarded to the latter.

    Returns:
        The newly constructed block.

    Raises:
        ValueError: If ``"CrossAttnDownBlock3D"`` is requested while
            ``cross_attention_dim`` is ``None``, or if the (prefix-stripped)
            type name is not recognised.
    """
    block_name = down_block_type
    if block_name.startswith("UNetRes"):
        block_name = block_name[7:]

    # Keyword arguments accepted by both block classes.
    common_kwargs = dict(
        num_layers=num_layers,
        in_channels=in_channels,
        out_channels=out_channels,
        temb_channels=temb_channels,
        add_downsample=add_downsample,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
        resnet_groups=resnet_groups,
        downsample_padding=downsample_padding,
        resnet_time_scale_shift=resnet_time_scale_shift,
        use_inflated_groupnorm=use_inflated_groupnorm,
        use_motion_module=use_motion_module,
        motion_module_type=motion_module_type,
        motion_module_kwargs=motion_module_kwargs,
        use_temporal_module=use_temporal_module,
        temporal_module_type=temporal_module_type,
        temporal_module_kwargs=temporal_module_kwargs,
    )

    if block_name == "DownBlock3D":
        return DownBlock3D(**common_kwargs)

    if block_name == "CrossAttnDownBlock3D":
        if cross_attention_dim is None:
            raise ValueError(
                "cross_attention_dim must be specified for CrossAttnDownBlock3D"
            )
        return CrossAttnDownBlock3D(
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
            unet_use_temporal_attention=unet_use_temporal_attention,
            **common_kwargs,
        )

    raise ValueError(f"{block_name} does not exist.")
110
-
111
-
112
def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    unet_use_cross_frame_attention=None,
    unet_use_temporal_attention=None,
    use_inflated_groupnorm=None,
    use_motion_module=None,
    motion_module_type=None,
    motion_module_kwargs=None,
    use_temporal_module=None,
    temporal_module_type=None,
    temporal_module_kwargs=None,
):
    """Instantiate the up block selected by ``up_block_type``.

    A leading ``"UNetRes"`` prefix is dropped from the type name before it is
    matched. The recognised names are ``"UpBlock3D"`` and
    ``"CrossAttnUpBlock3D"``; the attention-related arguments are only
    forwarded to the latter.

    Returns:
        The newly constructed block.

    Raises:
        ValueError: If ``"CrossAttnUpBlock3D"`` is requested while
            ``cross_attention_dim`` is ``None``, or if the (prefix-stripped)
            type name is not recognised.
    """
    block_name = up_block_type
    if block_name.startswith("UNetRes"):
        block_name = block_name[7:]

    # Keyword arguments accepted by both block classes.
    common_kwargs = dict(
        num_layers=num_layers,
        in_channels=in_channels,
        out_channels=out_channels,
        prev_output_channel=prev_output_channel,
        temb_channels=temb_channels,
        add_upsample=add_upsample,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
        resnet_groups=resnet_groups,
        resnet_time_scale_shift=resnet_time_scale_shift,
        use_inflated_groupnorm=use_inflated_groupnorm,
        use_motion_module=use_motion_module,
        motion_module_type=motion_module_type,
        motion_module_kwargs=motion_module_kwargs,
        use_temporal_module=use_temporal_module,
        temporal_module_type=temporal_module_type,
        temporal_module_kwargs=temporal_module_kwargs,
    )

    if block_name == "UpBlock3D":
        return UpBlock3D(**common_kwargs)

    if block_name == "CrossAttnUpBlock3D":
        if cross_attention_dim is None:
            raise ValueError(
                "cross_attention_dim must be specified for CrossAttnUpBlock3D"
            )
        return CrossAttnUpBlock3D(
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
            unet_use_temporal_attention=unet_use_temporal_attention,
            **common_kwargs,
        )

    raise ValueError(f"{block_name} does not exist.")
196
-
197
-
198
class UNetMidBlock3DCrossAttn(nn.Module):
    """Middle block: one resnet followed by ``num_layers`` rounds of
    attention, optional motion module, optional temporal module and resnet.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        use_temporal_module=None,
        temporal_module_type=None,
        temporal_module_kwargs=None,
        **transformer_kwargs,
    ):
        """Create the block's submodules.

        Input and output channel counts of every resnet equal
        ``in_channels``. Extra keyword arguments are forwarded to each
        ``Transformer3DModel``.

        Raises:
            NotImplementedError: If ``dual_cross_attention`` is truthy and
                ``num_layers`` is at least 1.
        """
        super().__init__()

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        # All resnets of this block share one configuration.
        resnet_config = dict(
            in_channels=in_channels,
            out_channels=in_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
            use_inflated_groupnorm=use_inflated_groupnorm,
        )

        # The block always owns one leading resnet; each layer adds another.
        resnet_layers = [ResnetBlock3D(**resnet_config)]
        attention_layers = []
        motion_layers = []

        for _ in range(num_layers):
            if dual_cross_attention:
                raise NotImplementedError
            attention_layers.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    in_channels // attn_num_head_channels,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                    **transformer_kwargs,
                )
            )
            if use_motion_module:
                motion_layers.append(
                    get_motion_module(
                        in_channels=in_channels,
                        motion_module_type=motion_module_type,
                        motion_module_kwargs=motion_module_kwargs,
                    )
                )
            else:
                # Placeholder keeps the list aligned with the other layers.
                motion_layers.append(None)
            resnet_layers.append(ResnetBlock3D(**resnet_config))

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)
        self.motion_modules = nn.ModuleList(motion_layers)

        # Temporal modules are created after the other submodules.
        temporal_layers = []
        for _ in range(num_layers):
            if use_temporal_module:
                temporal_layers.append(
                    get_motion_module(
                        in_channels=in_channels,
                        motion_module_type=temporal_module_type,
                        motion_module_kwargs=temporal_module_kwargs,
                    )
                )
            else:
                temporal_layers.append(None)
        self.temporal_modules = nn.ModuleList(temporal_layers)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        temb=None,
        encoder_hidden_states=None,
        attention_mask=None,
        skip_mm=False,
    ):
        """Run the block and return the transformed ``hidden_states``.

        When ``encoder_hidden_states`` is a ``list`` it is unpacked into the
        conditioning for the attention layers and the conditioning for the
        motion modules; otherwise the same object is used for both.
        ``skip_mm`` disables both the motion and the temporal modules.
        ``attention_mask`` is accepted but not used here.
        """
        if isinstance(encoder_hidden_states, list):
            encoder_hidden_states, motion_hidden_states = encoder_hidden_states
        else:
            motion_hidden_states = encoder_hidden_states

        def bind(module, **fixed_kwargs):
            # Positional-only callable as expected by the checkpoint utility.
            def call(*inputs):
                return module(*inputs, **fixed_kwargs)

            return call

        hidden_states = self.resnets[0](hidden_states, temb)

        layers = zip(
            self.attentions,
            self.resnets[1:],
            self.motion_modules,
            self.temporal_modules,
        )
        for attn, resnet, motion_module, temporal_module in layers:
            if self.training and self.gradient_checkpointing:
                run_ckpt = torch.utils.checkpoint.checkpoint
                hidden_states = run_ckpt(
                    bind(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )[0]
                if motion_module is not None and not skip_mm:
                    hidden_states = run_ckpt(
                        bind(motion_module),
                        hidden_states,
                        temb,
                        motion_hidden_states,
                    )
                if temporal_module is not None and not skip_mm:
                    hidden_states = run_ckpt(
                        bind(temporal_module),
                        hidden_states.requires_grad_(),
                        temb,
                        None,
                    )
                hidden_states = run_ckpt(bind(resnet), hidden_states, temb)
            else:
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                ).sample
                if motion_module is not None and not skip_mm:
                    hidden_states = motion_module(
                        hidden_states, temb, encoder_hidden_states=motion_hidden_states
                    )
                if temporal_module is not None and not skip_mm:
                    hidden_states = temporal_module(
                        hidden_states, temb, encoder_hidden_states=None, debug=True
                    )
                hidden_states = resnet(hidden_states, temb)

        return hidden_states
391
-
392
-
393
class CrossAttnDownBlock3D(nn.Module):
    """Down block whose layers are resnet -> attention -> optional motion
    module -> optional temporal module, followed by an optional downsampler.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        use_temporal_module=None,
        temporal_module_type=None,
        temporal_module_kwargs=None,
        **transformer_kwargs,
    ):
        """Create the block's submodules.

        The first resnet maps ``in_channels`` to ``out_channels``; later
        resnets map ``out_channels`` to ``out_channels``. Extra keyword
        arguments are forwarded to each ``Transformer3DModel``.

        Raises:
            NotImplementedError: If ``dual_cross_attention`` is truthy and
                ``num_layers`` is at least 1.
        """
        super().__init__()
        resnet_layers = []
        attention_layers = []
        motion_layers = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for layer_idx in range(num_layers):
            layer_in_channels = in_channels if layer_idx == 0 else out_channels
            resnet_layers.append(
                ResnetBlock3D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attention_layers.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                    **transformer_kwargs,
                )
            )
            if use_motion_module:
                motion_layers.append(
                    get_motion_module(
                        in_channels=out_channels,
                        motion_module_type=motion_module_type,
                        motion_module_kwargs=motion_module_kwargs,
                    )
                )
            else:
                # Placeholder keeps the list aligned with the other layers.
                motion_layers.append(None)

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)
        self.motion_modules = nn.ModuleList(motion_layers)

        # Temporal modules are created after the other per-layer submodules.
        temporal_layers = []
        for _ in range(num_layers):
            if use_temporal_module:
                temporal_layers.append(
                    get_motion_module(
                        in_channels=out_channels,
                        motion_module_type=temporal_module_type,
                        motion_module_kwargs=temporal_module_kwargs,
                    )
                )
            else:
                temporal_layers.append(None)
        self.temporal_modules = nn.ModuleList(temporal_layers)

        self.downsamplers = None
        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                    )
                ]
            )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        temb=None,
        encoder_hidden_states=None,
        attention_mask=None,
        skip_mm=False
    ):
        """Run the block.

        When ``encoder_hidden_states`` is a ``list`` it is unpacked into the
        conditioning for the attention layers and the conditioning for the
        motion modules; otherwise the same object is used for both.
        ``skip_mm`` disables both the motion and the temporal modules.
        ``attention_mask`` is accepted but not used here.

        Returns:
            ``(hidden_states, output_states)`` where ``output_states`` is a
            tuple holding the output of every layer plus, if downsamplers
            exist, the downsampled result.
        """
        if isinstance(encoder_hidden_states, list):
            encoder_hidden_states, motion_hidden_states = encoder_hidden_states
        else:
            motion_hidden_states = encoder_hidden_states

        def bind(module, **fixed_kwargs):
            # Positional-only callable as expected by the checkpoint utility.
            def call(*inputs):
                return module(*inputs, **fixed_kwargs)

            return call

        collected = []
        layers = zip(
            self.resnets, self.attentions, self.motion_modules, self.temporal_modules
        )
        for resnet, attn, motion_module, temporal_module in layers:
            if self.training and self.gradient_checkpointing:
                run_ckpt = torch.utils.checkpoint.checkpoint
                hidden_states = run_ckpt(bind(resnet), hidden_states, temb)
                hidden_states = run_ckpt(
                    bind(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )[0]

                if motion_module is not None and not skip_mm:
                    hidden_states = run_ckpt(
                        bind(motion_module),
                        hidden_states,
                        temb,
                        motion_hidden_states,
                    )
                # In this block the temporal module is called directly,
                # not through the checkpoint utility.
                if temporal_module is not None and not skip_mm:
                    hidden_states = temporal_module(
                        hidden_states, temb, encoder_hidden_states=None
                    )
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                ).sample

                if motion_module is not None and not skip_mm:
                    hidden_states = motion_module(
                        hidden_states, temb, encoder_hidden_states=motion_hidden_states
                    )
                if temporal_module is not None and not skip_mm:
                    hidden_states = temporal_module(
                        hidden_states, temb, encoder_hidden_states=None, debug=True
                    )

            collected.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            collected.append(hidden_states)

        return hidden_states, tuple(collected)
606
-
607
-
608
class DownBlock3D(nn.Module):
    """Down block whose layers are resnet -> optional motion module ->
    optional temporal module, followed by an optional downsampler.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        use_temporal_module=None,
        temporal_module_type=None,
        temporal_module_kwargs=None,
    ):
        """Create the block's submodules.

        The first resnet maps ``in_channels`` to ``out_channels``; later
        resnets map ``out_channels`` to ``out_channels``.
        """
        super().__init__()
        resnet_layers = []
        motion_layers = []

        for layer_idx in range(num_layers):
            layer_in_channels = in_channels if layer_idx == 0 else out_channels
            resnet_layers.append(
                ResnetBlock3D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            if use_motion_module:
                motion_layers.append(
                    get_motion_module(
                        in_channels=out_channels,
                        motion_module_type=motion_module_type,
                        motion_module_kwargs=motion_module_kwargs,
                    )
                )
            else:
                # Placeholder keeps the list aligned with the resnets.
                motion_layers.append(None)

        self.resnets = nn.ModuleList(resnet_layers)
        self.motion_modules = nn.ModuleList(motion_layers)

        # Temporal modules are created after the other per-layer submodules.
        temporal_layers = []
        for _ in range(num_layers):
            if use_temporal_module:
                temporal_layers.append(
                    get_motion_module(
                        in_channels=out_channels,
                        motion_module_type=temporal_module_type,
                        motion_module_kwargs=temporal_module_kwargs,
                    )
                )
            else:
                temporal_layers.append(None)
        self.temporal_modules = nn.ModuleList(temporal_layers)

        self.downsamplers = None
        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                    )
                ]
            )

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, skip_mm=False):
        """Run the block.

        When ``encoder_hidden_states`` is a ``list`` its second element is
        used as the conditioning for the motion modules; otherwise the object
        itself is used. ``skip_mm`` disables both the motion and the temporal
        modules.

        Returns:
            ``(hidden_states, output_states)`` where ``output_states`` is a
            tuple holding the output of every layer plus, if downsamplers
            exist, the downsampled result.
        """
        collected = []
        if isinstance(encoder_hidden_states, list):
            encoder_hidden_states, motion_hidden_states = encoder_hidden_states
        else:
            motion_hidden_states = encoder_hidden_states

        def bind(module):
            # Positional-only callable as expected by the checkpoint utility.
            def call(*inputs):
                return module(*inputs)

            return call

        layers = zip(self.resnets, self.motion_modules, self.temporal_modules)
        for resnet, motion_module, temporal_module in layers:
            if self.training and self.gradient_checkpointing:
                run_ckpt = torch.utils.checkpoint.checkpoint
                hidden_states = run_ckpt(bind(resnet), hidden_states, temb)
                if motion_module is not None and not skip_mm:
                    hidden_states = run_ckpt(
                        bind(motion_module),
                        hidden_states,
                        temb,
                        motion_hidden_states,
                    )

                if temporal_module is not None and not skip_mm:
                    hidden_states = run_ckpt(
                        bind(temporal_module),
                        hidden_states.requires_grad_(),
                        temb,
                        None,
                    )
            else:
                hidden_states = resnet(hidden_states, temb)

                if motion_module is not None and not skip_mm:
                    hidden_states = motion_module(
                        hidden_states, temb, encoder_hidden_states=motion_hidden_states
                    )
                if temporal_module is not None and not skip_mm:
                    hidden_states = temporal_module(
                        hidden_states, temb, encoder_hidden_states=None, debug=True
                    )

            collected.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            collected.append(hidden_states)

        return hidden_states, tuple(collected)
763
-
764
-
765
class CrossAttnUpBlock3D(nn.Module):
    """Up block whose layers concatenate a skip connection and then run
    resnet -> attention -> optional motion module -> optional temporal
    module, followed by an optional upsampler.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_motion_module=None,
        use_inflated_groupnorm=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        use_temporal_module=None,
        temporal_module_type=None,
        temporal_module_kwargs=None,
        **transformer_kwargs,
    ):
        """Create the block's submodules.

        Each resnet's input width is the sum of the incoming feature width
        (``prev_output_channel`` for the first layer, ``out_channels``
        afterwards) and the skip width (``in_channels`` for the last layer,
        ``out_channels`` otherwise). Extra keyword arguments are forwarded to
        each ``Transformer3DModel``.

        Raises:
            NotImplementedError: If ``dual_cross_attention`` is truthy and
                ``num_layers`` is at least 1.
        """
        super().__init__()
        resnet_layers = []
        attention_layers = []
        motion_layers = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        last_idx = num_layers - 1
        for layer_idx in range(num_layers):
            skip_width = in_channels if layer_idx == last_idx else out_channels
            feature_width = prev_output_channel if layer_idx == 0 else out_channels

            resnet_layers.append(
                ResnetBlock3D(
                    in_channels=feature_width + skip_width,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attention_layers.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                    **transformer_kwargs,
                )
            )
            if use_motion_module:
                motion_layers.append(
                    get_motion_module(
                        in_channels=out_channels,
                        motion_module_type=motion_module_type,
                        motion_module_kwargs=motion_module_kwargs,
                    )
                )
            else:
                # Placeholder keeps the list aligned with the other layers.
                motion_layers.append(None)

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)
        self.motion_modules = nn.ModuleList(motion_layers)

        # Temporal modules are created after the other per-layer submodules.
        temporal_layers = []
        for _ in range(num_layers):
            if use_temporal_module:
                temporal_layers.append(
                    get_motion_module(
                        in_channels=out_channels,
                        motion_module_type=temporal_module_type,
                        motion_module_kwargs=temporal_module_kwargs,
                    )
                )
            else:
                temporal_layers.append(None)
        self.temporal_modules = nn.ModuleList(temporal_layers)

        self.upsamplers = None
        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
            )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
        attention_mask=None,
        skip_mm=False,
    ):
        """Run the block and return the transformed ``hidden_states``.

        For every layer the last entry of ``res_hidden_states_tuple`` is
        consumed and concatenated to ``hidden_states`` along dimension 1.
        When ``encoder_hidden_states`` is a ``list`` it is unpacked into the
        conditioning for the attention layers and the conditioning for the
        motion modules; otherwise the same object is used for both.
        ``skip_mm`` disables both the motion and the temporal modules.
        ``attention_mask`` is accepted but not used here.
        """
        if isinstance(encoder_hidden_states, list):
            encoder_hidden_states, motion_hidden_states = encoder_hidden_states
        else:
            motion_hidden_states = encoder_hidden_states

        def bind(module, **fixed_kwargs):
            # Positional-only callable as expected by the checkpoint utility.
            def call(*inputs):
                return module(*inputs, **fixed_kwargs)

            return call

        layers = zip(
            self.resnets, self.attentions, self.motion_modules, self.temporal_modules
        )
        for resnet, attn, motion_module, temporal_module in layers:
            # Take the most recent skip connection off the end of the tuple.
            skip_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, skip_states], dim=1)

            if self.training and self.gradient_checkpointing:
                run_ckpt = torch.utils.checkpoint.checkpoint
                hidden_states = run_ckpt(bind(resnet), hidden_states, temb)
                hidden_states = run_ckpt(
                    bind(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )[0]
                if motion_module is not None and not skip_mm:
                    hidden_states = run_ckpt(
                        bind(motion_module),
                        hidden_states,
                        temb,
                        motion_hidden_states,
                    )
                if temporal_module is not None and not skip_mm:
                    hidden_states = run_ckpt(
                        bind(temporal_module),
                        hidden_states.requires_grad_(),
                        temb,
                        None,
                    )
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                ).sample

                if motion_module is not None and not skip_mm:
                    hidden_states = motion_module(
                        hidden_states, temb, encoder_hidden_states=motion_hidden_states
                    )

                if temporal_module is not None and not skip_mm:
                    hidden_states = temporal_module(
                        hidden_states, temb, encoder_hidden_states=None, debug=True
                    )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
969
-
970
-
971
class UpBlock3D(nn.Module):
    """Up-sampling UNet block built from 3D resnets without cross-attention.

    The block holds ``num_layers`` layers.  Every layer is a ``ResnetBlock3D``
    optionally followed by a motion module and a temporal module, both created
    through ``get_motion_module``.  When ``add_upsample`` is true, a single
    ``Upsample3D`` is applied after the last layer.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        use_temporal_module=None,
        temporal_module_type=None,
        temporal_module_kwargs=None,
    ):
        """Create the per-layer resnets, motion/temporal modules and upsampler.

        Args:
            in_channels: Channel count of the skip tensor consumed by the
                last layer of this block.
            prev_output_channel: Channel count of the incoming hidden states
                seen by the first layer.
            out_channels: Channel count produced by every layer.
            temb_channels: Forwarded to ``ResnetBlock3D`` as ``temb_channels``.
            dropout: Forwarded to ``ResnetBlock3D``.
            num_layers: Number of resnet (+ optional module) layers.
            resnet_eps: Forwarded to ``ResnetBlock3D`` as ``eps``.
            resnet_time_scale_shift: Forwarded as ``time_embedding_norm``.
            resnet_act_fn: Forwarded as ``non_linearity``.
            resnet_groups: Forwarded as ``groups``.
            resnet_pre_norm: Forwarded as ``pre_norm``.
            output_scale_factor: Forwarded to ``ResnetBlock3D``.
            add_upsample: Whether to append an ``Upsample3D`` stage.
            use_inflated_groupnorm: Forwarded to ``ResnetBlock3D``.
            use_motion_module: If truthy, build one motion module per layer;
                otherwise the slot holds ``None``.
            motion_module_type: Forwarded to ``get_motion_module``.
            motion_module_kwargs: Forwarded to ``get_motion_module``.
            use_temporal_module: If truthy, build one temporal module per
                layer; otherwise the slot holds ``None``.
            temporal_module_type: Forwarded to ``get_motion_module``.
            temporal_module_kwargs: Forwarded to ``get_motion_module``.
        """
        super().__init__()

        def _optional_module(enabled, module_type, module_kwargs):
            # A disabled slot is kept as ``None`` so the three module lists
            # stay index-aligned with ``self.resnets``.
            if not enabled:
                return None
            return get_motion_module(
                in_channels=out_channels,
                motion_module_type=module_type,
                motion_module_kwargs=module_kwargs,
            )

        resnet_layers = []
        motion_layers = []
        last_layer = num_layers - 1

        # Resnet and motion module are created layer by layer (interleaved)
        # so that parameter initialisation happens in the same order as before.
        for layer_idx in range(num_layers):
            # ``forward`` concatenates one skip tensor onto the hidden states
            # before each resnet, hence the summed input width.
            hidden_width = prev_output_channel if layer_idx == 0 else out_channels
            skip_width = in_channels if layer_idx == last_layer else out_channels

            resnet_layers.append(
                ResnetBlock3D(
                    in_channels=hidden_width + skip_width,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            motion_layers.append(
                _optional_module(
                    use_motion_module, motion_module_type, motion_module_kwargs
                )
            )

        temporal_layers = [
            _optional_module(
                use_temporal_module, temporal_module_type, temporal_module_kwargs
            )
            for _ in range(num_layers)
        ]

        self.resnets = nn.ModuleList(resnet_layers)
        self.motion_modules = nn.ModuleList(motion_layers)
        self.temporal_modules = nn.ModuleList(temporal_layers)

        self.upsamplers = (
            nn.ModuleList(
                [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
            )
            if add_upsample
            else None
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        upsample_size=None,
        encoder_hidden_states=None,
        skip_mm=False,
    ):
        """Run every layer, then the optional upsampler.

        Args:
            hidden_states: Input tensor; concatenated with a skip tensor along
                ``dim=1`` before each resnet.
            res_hidden_states_tuple: Skip tensors, consumed from the end: the
                first layer takes the last entry, the next layer the one
                before it, and so on.
            temb: Passed to every resnet, motion module and temporal module.
            upsample_size: Passed to each upsampler.
            encoder_hidden_states: Either a single conditioning object, which
                is given to the motion modules, or a ``list`` of two items of
                which the second is given to the motion modules.
            skip_mm: When true, motion and temporal modules are bypassed.

        Returns:
            The hidden states after all layers and the optional upsampling.
        """
        # Only the motion-module conditioning is used by this block.
        motion_hidden_states = encoder_hidden_states
        if isinstance(encoder_hidden_states, list):
            encoder_hidden_states, motion_hidden_states = encoder_hidden_states

        def _as_plain_callable(module):
            def _call(*inputs):
                return module(*inputs)

            return _call

        checkpoint = torch.utils.checkpoint.checkpoint
        use_checkpoint = self.training and self.gradient_checkpointing
        modules_enabled = not skip_mm

        layers = zip(self.resnets, self.motion_modules, self.temporal_modules)
        for depth, (resnet, motion_module, temporal_module) in enumerate(layers, 1):
            # Take the skip connections from the back of the tuple.
            skip_states = res_hidden_states_tuple[-depth]
            hidden_states = torch.cat([hidden_states, skip_states], dim=1)

            run_motion = modules_enabled and motion_module is not None
            run_temporal = modules_enabled and temporal_module is not None

            if use_checkpoint:
                hidden_states = checkpoint(
                    _as_plain_callable(resnet), hidden_states, temb
                )
                if run_motion:
                    hidden_states = checkpoint(
                        _as_plain_callable(motion_module),
                        hidden_states,
                        temb,
                        motion_hidden_states,
                    )
                if run_temporal:
                    hidden_states = checkpoint(
                        _as_plain_callable(temporal_module),
                        hidden_states.requires_grad_(),
                        temb,
                        None,
                    )
                continue

            hidden_states = resnet(hidden_states, temb)
            if run_motion:
                hidden_states = motion_module(
                    hidden_states, temb, encoder_hidden_states=motion_hidden_states
                )
            if run_temporal:
                # NOTE(review): ``debug=True`` is passed only on this
                # non-checkpointed path; confirm that this is intentional.
                hidden_states = temporal_module(
                    hidden_states, temb, encoder_hidden_states=None, debug=True
                )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states