Spaces:
Configuration error
Configuration error
| # Copy from diffusers.models.unets.unet_2d_condition.py | |
| # Copyright 2024 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from typing import Any, Dict, List, Optional, Tuple, Union | |
| import torch | |
| from diffusers.utils import logging | |
| from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel | |
| logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
| class ExpandKVUNet2DConditionModel(UNet2DConditionModel): | |
| r""" | |
| A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample | |
| shaped output. | |
| This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented | |
| for all models (such as downloading or saving). | |
| Parameters: | |
| sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): | |
| Height and width of input/output sample. | |
| in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. | |
| out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. | |
| center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. | |
| flip_sin_to_cos (`bool`, *optional*, defaults to `True`): | |
| Whether to flip the sin to cos in the time embedding. | |
| freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. | |
| down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): | |
| The tuple of downsample blocks to use. | |
| mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): | |
| Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or | |
| `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. | |
| up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): | |
| The tuple of upsample blocks to use. | |
| only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): | |
| Whether to include self-attention in the basic transformer blocks, see | |
| [`~models.attention.BasicTransformerBlock`]. | |
| block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): | |
| The tuple of output channels for each block. | |
| layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. | |
| downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. | |
| mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. | |
| dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. | |
| act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. | |
| norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. | |
| If `None`, normalization and activation layers are skipped in post-processing. | |
| norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. | |
| cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): | |
| The dimension of the cross attention features. | |
| transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): | |
| The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for | |
| [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], | |
| [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. | |
| reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): | |
| The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling | |
| blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for | |
| [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], | |
| [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. | |
| encoder_hid_dim (`int`, *optional*, defaults to None): | |
| If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` | |
| dimension to `cross_attention_dim`. | |
| encoder_hid_dim_type (`str`, *optional*, defaults to `None`): | |
| If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text | |
| embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. | |
| attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. | |
| num_attention_heads (`int`, *optional*): | |
| The number of attention heads. If not defined, defaults to `attention_head_dim` | |
| resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config | |
| for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. | |
| class_embed_type (`str`, *optional*, defaults to `None`): | |
| The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, | |
| `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. | |
| addition_embed_type (`str`, *optional*, defaults to `None`): | |
| Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or | |
| "text". "text" will use the `TextTimeEmbedding` layer. | |
| addition_time_embed_dim: (`int`, *optional*, defaults to `None`): | |
| Dimension for the timestep embeddings. | |
| num_class_embeds (`int`, *optional*, defaults to `None`): | |
| Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing | |
| class conditioning with `class_embed_type` equal to `None`. | |
| time_embedding_type (`str`, *optional*, defaults to `positional`): | |
| The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. | |
| time_embedding_dim (`int`, *optional*, defaults to `None`): | |
| An optional override for the dimension of the projected time embedding. | |
| time_embedding_act_fn (`str`, *optional*, defaults to `None`): | |
| Optional activation function to use only once on the time embeddings before they are passed to the rest of | |
| the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. | |
| timestep_post_act (`str`, *optional*, defaults to `None`): | |
| The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. | |
| time_cond_proj_dim (`int`, *optional*, defaults to `None`): | |
| The dimension of `cond_proj` layer in the timestep embedding. | |
| conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. | |
| conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. | |
| projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when | |
| `class_embed_type="projection"`. Required when `class_embed_type="projection"`. | |
| class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time | |
| embeddings with the class embeddings. | |
| mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): | |
| Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If | |
| `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the | |
| `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` | |
| otherwise. | |
| """ | |
def process_encoder_hidden_states(
    self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
) -> torch.Tensor:
    r"""
    Apply `self.encoder_hid_proj` to the conditioning inputs as selected by
    `self.config.encoder_hid_dim_type`.

    Args:
        encoder_hidden_states: The conditioning states handed to the UNet.
        added_cond_kwargs: Extra conditioning inputs. Every projection type other than
            `"text_proj"` requires the key `"image_embeds"`; `"instantir"` additionally
            requires the key `"extract_kvs"`.

    Returns:
        - `encoder_hidden_states` unchanged when `self.encoder_hid_proj` is `None` or the
          configured type is not one of the types handled below.
        - `"text_proj"`: `encoder_hid_proj(encoder_hidden_states)`.
        - `"text_image_proj"`: `encoder_hid_proj(encoder_hidden_states, image_embeds)`.
        - `"image_proj"`: `encoder_hid_proj(image_embeds)`.
        - `"ip_image_proj"` / `"instantir"`: the 2-tuple
          `(encoder_hidden_states, encoder_hid_proj(image_embeds))`.

    Raises:
        ValueError: If a key required by the configured projection type is missing from
            `added_cond_kwargs`.
    """
    # Without a projection layer the configuration is never consulted.
    if self.encoder_hid_proj is None:
        return encoder_hidden_states

    hid_dim_type = self.config.encoder_hid_dim_type

    if hid_dim_type == "text_proj":
        return self.encoder_hid_proj(encoder_hidden_states)

    # Keys that must be present in `added_cond_kwargs`, listed in the order they are checked.
    required_keys_by_type = {
        "text_image_proj": ("image_embeds",),
        "image_proj": ("image_embeds",),
        "ip_image_proj": ("image_embeds",),
        "instantir": ("image_embeds", "extract_kvs"),
    }
    required_keys = required_keys_by_type.get(hid_dim_type)
    if required_keys is None:
        # Unrecognised projection type: hand the states back untouched.
        return encoder_hidden_states

    for key in required_keys:
        if key not in added_cond_kwargs:
            # Name the configured type and the key that is actually missing.
            raise ValueError(
                f"{self.__class__} has the config param `encoder_hid_dim_type` set to '{hid_dim_type}' which requires the keyword argument `{key}` to be passed in `added_conditions`"
            )

    image_embeds = added_cond_kwargs.get("image_embeds")

    if hid_dim_type == "text_image_proj":
        # Kandinsky 2.1 - style
        return self.encoder_hid_proj(encoder_hidden_states, image_embeds)

    if hid_dim_type == "image_proj":
        # Kandinsky 2.2 - style
        return self.encoder_hid_proj(image_embeds)

    # "ip_image_proj" and "instantir": keep the incoming states and pair them with the
    # projected image embeddings. For "instantir", `extract_kvs` is only validated here;
    # this method does not read its value (presumably it is consumed elsewhere - TODO confirm).
    image_embeds = self.encoder_hid_proj(image_embeds)
    return (encoder_hidden_states, image_embeds)