| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from dataclasses import dataclass |
| | from typing import Any, Dict, List, Optional, Tuple, Union |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.utils.checkpoint |
| |
|
| | from ...configuration_utils import ConfigMixin, register_to_config |
| | from ...loaders import UNet2DConditionLoadersMixin |
| | from ...utils import BaseOutput, logging |
| | from ..activations import get_activation |
| | from ..attention_processor import ( |
| | ADDED_KV_ATTENTION_PROCESSORS, |
| | CROSS_ATTENTION_PROCESSORS, |
| | Attention, |
| | AttentionProcessor, |
| | AttnAddedKVProcessor, |
| | AttnProcessor, |
| | FusedAttnProcessor2_0, |
| | ) |
| | from ..embeddings import TimestepEmbedding, Timesteps |
| | from ..modeling_utils import ModelMixin |
| | from ..transformers.transformer_temporal import TransformerTemporalModel |
| | from .unet_3d_blocks import ( |
| | CrossAttnDownBlock3D, |
| | CrossAttnUpBlock3D, |
| | DownBlock3D, |
| | UNetMidBlock3DCrossAttn, |
| | UpBlock3D, |
| | get_down_block, |
| | get_up_block, |
| | ) |
| |
|
| |
|
# Module-level logger from the package's `logging` utilities (imported from `...utils`); used by
# `UNet3DConditionModel.forward` to report when upsample sizes are forwarded explicitly.
logger = logging.get_logger(__name__)
| |
|
| |
|
@dataclass
class UNet3DConditionOutput(BaseOutput):
    """
    Container for the result of [`UNet3DConditionModel`].

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
            Final-layer hidden states of the model, conditioned on the `encoder_hidden_states` input.
    """

    # Final-layer output tensor of the 3D UNet.
    sample: torch.Tensor
| |
|
| |
|
class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
    shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
            The tuple of downsample blocks to use.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
            The tuple of upsample blocks to use.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function used by the time embedding and
            the down/mid/up blocks. The output activation (`conv_act`) is always SiLU.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, normalization and activation layers are skipped in post-processing.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
        attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads. Note that
            this value is currently forwarded to the blocks as their number of attention heads (see the naming
            issue referenced in the `num_attention_heads` error message).
        num_attention_heads (`int`, *optional*): The number of attention heads. Not supported yet; passing a
            value raises `NotImplementedError`.
        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
            The dimension of `cond_proj` layer in the timestep embedding.
    """

    _supports_gradient_checkpointing = False

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        down_block_types: Tuple[str, ...] = (
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        up_block_types: Tuple[str, ...] = (
            "UpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1024,
        attention_head_dim: Union[int, Tuple[int]] = 64,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        time_cond_proj_dim: Optional[int] = None,
    ):
        super().__init__()

        self.sample_size = sample_size

        if num_attention_heads is not None:
            raise NotImplementedError(
                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
            )

        # `num_attention_heads` is rejected above, so this always falls back to `attention_head_dim`, which is
        # then forwarded to the blocks as their number of heads (the naming issue linked in the message above).
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        # Input convolution; "same" padding keeps the spatial size unchanged.
        conv_in_kernel = 3
        conv_out_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # Time embedding: sinusoidal projection followed by an MLP of width 4 * block_out_channels[0].
        time_embed_dim = block_out_channels[0] * 4
        self.time_proj = Timesteps(block_out_channels[0], True, 0)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
            cond_proj_dim=time_cond_proj_dim,
        )

        self.transformer_in = TransformerTemporalModel(
            num_attention_heads=8,
            attention_head_dim=attention_head_dim,
            in_channels=block_out_channels[0],
            num_layers=1,
            norm_num_groups=norm_num_groups,
        )

        # Class embedding
        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        # Down blocks: every block except the last one halves the resolution.
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=False,
            )
            self.down_blocks.append(down_block)

        # Mid block
        self.mid_block = UNetMidBlock3DCrossAttn(
            in_channels=block_out_channels[-1],
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads[-1],
            resnet_groups=norm_num_groups,
            dual_cross_attention=False,
        )

        # Counts the upsampling stages; `forward` uses it to decide whether input sizes divide evenly.
        self.num_upsamplers = 0

        # Up blocks mirror the down blocks in reverse order.
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # Every up block except the last one upsamples.
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=reversed_num_attention_heads[i],
                dual_cross_attention=False,
                resolution_idx=i,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # Output head: optional GroupNorm + SiLU, then a "same"-padded convolution to `out_channels`.
        if norm_num_groups is not None:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )
            self.conv_act = get_activation("silu")
        else:
            self.conv_norm_out = None
            self.conv_act = None

        conv_out_padding = (conv_out_kernel - 1) // 2
        self.conv_out = nn.Conv2d(
            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        )

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            keyed by their weight name (`"<module path>.processor"`).
        """
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
        several steps. This is useful for saving some memory in exchange for a small decrease in speed.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.

        Raises:
            ValueError: If a list of the wrong length is given, or a slice size exceeds its layer's sliceable dim.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # Collect the sliceable dims of every attention layer, in module traversal order.
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # Half the sliceable dim of each layer -> attention computed in two steps.
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # One slice at a time -> lowest memory.
            slice_size = num_sliceable_layers * [1]

        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # The same traversal order is used as above; popping from the reversed list hands out sizes front to back.
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors. The passed dict
                is not modified.

        Raises:
            ValueError: If a dict is passed whose size differs from the number of attention layers.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        if isinstance(processor, dict):
            # Work on a shallow copy: entries are popped as they are assigned below, and popping from the caller's
            # dict would empty it (e.g. `self.original_attn_processors`, passed in by `unfuse_qkv_projections`).
            processor = dict(processor)

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        Enables [feed forward chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers) on every
        submodule that supports it.

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).

        Raises:
            ValueError: If `dim` is not 0 or 1.
        """
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # By default chunk size is 1
        chunk_size = chunk_size or 1

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)

    def disable_forward_chunking(self):
        """Disables feed forward chunking by passing `chunk_size=None, dim=0` to every submodule that supports it."""

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, None, 0)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.

        Raises:
            ValueError: If the current processors are a mix that is neither all added-KV nor all cross-attention.
        """
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor)

    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
        """Sets the `gradient_checkpointing` flag on `module` if it is one of the 3D down/up block types."""
        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
            module.gradient_checkpointing = value

    def enable_freeu(self, s1, s2, b1, b2):
        r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.

        The suffixes after the scaling factors represent the stage blocks where they are being applied.

        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
        are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.

        Args:
            s1 (`float`):
                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
                mitigate the "oversmoothing effect" in the enhanced denoising process.
            s2 (`float`):
                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
                mitigate the "oversmoothing effect" in the enhanced denoising process.
            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
        """
        # The factors are stored as attributes on every up block.
        for upsample_block in self.up_blocks:
            setattr(upsample_block, "s1", s1)
            setattr(upsample_block, "s2", s2)
            setattr(upsample_block, "b1", b1)
            setattr(upsample_block, "b2", b2)

    def disable_freeu(self):
        """Disables the FreeU mechanism by resetting any FreeU factors present on the up blocks to `None`."""
        freeu_keys = {"s1", "s2", "b1", "b2"}
        for upsample_block in self.up_blocks:
            for k in freeu_keys:
                if hasattr(upsample_block, k):
                    setattr(upsample_block, k, None)

    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        The current processors are saved in `self.original_attn_processors` so `unfuse_qkv_projections` can restore
        them.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        Raises:
            ValueError: If any attention processor uses added KV projections.
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedAttnProcessor2_0())

    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        Restores the processors saved by `fuse_qkv_projections`; does nothing if that method was never called.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        # `original_attn_processors` only exists once `fuse_qkv_projections` has run.
        if getattr(self, "original_attn_processors", None) is not None:
            self.set_attn_processor(self.original_attn_processors)

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet3DConditionOutput, Tuple[torch.Tensor]]:
        r"""
        The [`UNet3DConditionModel`] forward method.

        Args:
            sample (`torch.Tensor`):
                The noisy input tensor with the following shape `(batch, num_channels, num_frames, height, width)`.
            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.Tensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Currently not used by this model.
            timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
                Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
                through the `self.time_embedding` layer to obtain the timestep embeddings.
            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
                A tuple of tensors that if specified are added to the residuals of down unet blocks.
            mid_block_additional_residual: (`torch.Tensor`, *optional*):
                A tensor that if specified is added to the residual of the middle unet block.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
                tuple.

        Returns:
            [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] is returned,
                otherwise a `tuple` is returned where the first element is the sample tensor.
        """
        # The spatial size should be a multiple of 2**num_upsamplers; otherwise the up blocks are told the exact
        # target size of each skip connection so interpolation lands on it.
        default_overall_up_factor = 2**self.num_upsamplers

        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # Convert a keep(1)/discard(0) mask into an additive bias: kept -> 0, discarded -> -10000.
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # MPS does not get 64-bit dtypes here; the 32-bit variants are used on that device instead.
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # Broadcast the timestep(s) to the batch dimension.
        num_frames = sample.shape[2]
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # Cast the projected timesteps to the model's dtype before the embedding MLP (presumably so the input
        # matches `time_embedding`'s parameter dtype when the model runs in reduced precision).
        t_emb = t_emb.to(dtype=self.dtype)

        # Every frame is processed as a separate image, so per-sample conditioning is repeated per frame.
        emb = self.time_embedding(t_emb, timestep_cond)
        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
        encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)

        # 2. pre-process: (batch, channels, frames, h, w) -> (batch * frames, channels, h, w)
        sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
        sample = self.conv_in(sample)

        sample = self.transformer_in(
            sample,
            num_frames=num_frames,
            cross_attention_kwargs=cross_attention_kwargs,
            return_dict=False,
        )[0]

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    num_frames=num_frames,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)

            down_block_res_samples += res_samples

        if down_block_additional_residuals is not None:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples += (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                num_frames=num_frames,
                cross_attention_kwargs=cross_attention_kwargs,
            )

        if mid_block_additional_residual is not None:
            sample = sample + mid_block_additional_residual

        # 5. up: each block consumes as many skip tensors (from the end) as it has resnets.
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # If the input size is not a multiple of the overall up factor, pass the next skip tensor's spatial
            # size so the upsampler matches it exactly.
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    num_frames=num_frames,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    num_frames=num_frames,
                )

        # 6. post-process
        if self.conv_norm_out is not None:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)

        sample = self.conv_out(sample)

        # (batch * frames, channels, h, w) -> (batch, channels, frames, h, w)
        sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)

        if not return_dict:
            return (sample,)

        return UNet3DConditionOutput(sample=sample)
| |
|