from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalControlNetMixin
from diffusers.utils import BaseOutput, logging
from diffusers.models.embeddings import (
    TextImageProjection,
    TextImageTimeEmbedding,
    TextTimeEmbedding,
    TimestepEmbedding,
    Timesteps,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import (
    CrossAttnDownBlock2D,
    DownBlock2D,
    UNetMidBlock2DCrossAttn,
    get_down_block,
)
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.models.controlnet import (
    ControlNetConditioningEmbedding,
    ControlNetOutput,
    ControlNetModel,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class UNetDecControlNetOutput(ControlNetOutput):
    """
    Output of [`UNetDec_ControlNetModel`].

    Extends `ControlNetOutput` with the residuals produced by the decoder (up) blocks. It stays an instance of
    `ControlNetOutput`, so existing `isinstance` checks and attribute accesses keep working.

    Args:
        up_block_res_samples (`list[torch.Tensor]`, *optional*):
            The scaled outputs of the zero-initialised 1x1 convolutions, in the order the up blocks produced them.
    """

    up_block_res_samples: Optional[Tuple[torch.Tensor]] = None


class UNetDec_ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
    """
    A ControlNet-style model built from the mid block and the decoder (up) blocks of a UNet.

    No `conv_in` and no down blocks are created (`self.conv_in` is `None`, `self.down_blocks` is empty): `forward`
    feeds `sample` straight into the mid block and then through skip-free up blocks. Every up-block layer output is
    passed through its own zero-initialised 1x1 convolution.

    Args:
        in_channels (`int`, defaults to 4):
            The number of channels in the input sample.
        conditioning_channels (`int`, defaults to 3):
            Forwarded as `conditioning_channels` to `ControlNetConditioningEmbedding`.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, defaults to 0):
            The frequency shift to apply to the time embedding.
        down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks. Only its length is used here (to validate and broadcast the per-block
            arguments); no down block is instantiated.
        up_block_types (`tuple[str]`, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
            The tuple of upsample blocks to build. Each entry is resolved by `get_up_block`.
        only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
        block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, defaults to 2):
            The number of layers per block. Each up block is built with `layers_per_block + 1` layers.
        downsample_padding (`int`, defaults to 1):
            The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, defaults to 1):
            The scale factor to use for the mid block.
        act_fn (`str`, defaults to "silu"):
            The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups to use for the normalization. If None, normalization and activation layers is skipped
            in post-processing.
        norm_eps (`float`, defaults to 1e-5):
            The epsilon to use for the normalization.
        cross_attention_dim (`int`, defaults to 1280):
            The dimension of the cross attention features.
        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
        encoder_hid_dim (`int`, *optional*, defaults to None):
            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
            dimension to `cross_attention_dim`.
        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
        attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
            The dimension of the attention heads.
        num_attention_heads (`Union[int, Tuple[int]]`, *optional*):
            The number of attention heads. Falls back to `attention_head_dim` when not set.
        use_linear_projection (`bool`, defaults to `False`):
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
        addition_embed_type (`str`, *optional*, defaults to `None`):
            Configures an optional embedding which will be summed with the time embeddings. Choose from `None`,
            "text", "text_image" or "text_time". "text" will use the `TextTimeEmbedding` layer.
        addition_time_embed_dim (`int`, *optional*, defaults to `None`):
            Dimension of the `Timesteps` projection created when `addition_embed_type="text_time"`.
        num_class_embeds (`int`, *optional*, defaults to 0):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
        upcast_attention (`bool`, defaults to `False`):
        resnet_time_scale_shift (`str`, defaults to `"default"`):
            Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
        projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
            The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
            `class_embed_type="projection"`.
        controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
            The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
        conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
            The tuple of output channel for each block in the `conditioning_embedding` layer.
        global_pool_conditions (`bool`, defaults to `False`):
            When `True`, `forward` averages every residual over its two trailing (spatial) dimensions.
        addition_embed_type_num_heads (`int`, defaults to 64):
            Number of heads of the `TextTimeEmbedding` layer created when `addition_embed_type="text"`.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        conditioning_channels: int = 3,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types: Tuple[str] = (
            "UpBlock2D",
            "CrossAttnUpBlock2D",
            "CrossAttnUpBlock2D",
            "CrossAttnUpBlock2D",
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        projection_class_embeddings_input_dim: Optional[int] = None,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
        global_pool_conditions: bool = False,
        addition_embed_type_num_heads=64,
    ):
        super().__init__()

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in
        # https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        # input: this model has no input convolution; `forward` skips it when it is None.
        self.conv_in = None

        # time
        time_embed_dim = block_out_channels[0] * 4
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]
        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
        )

        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            encoder_hid_dim_type = "text_proj"
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
            )

        if encoder_hid_dim_type == "text_proj":
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
        elif encoder_hid_dim_type == "text_image_proj":
            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently
            # only use case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )
        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
            )
        else:
            self.encoder_hid_proj = None

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None

        if addition_embed_type == "text":
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            else:
                text_time_embedding_from_dim = cross_attention_dim

            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim,
                time_embed_dim,
                num_heads=addition_embed_type_num_heads,
            )
        elif addition_embed_type == "text_image":
            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__
            # too much they are set to `cross_attention_dim` here as this is exactly the required dimension for the
            # currently only use case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
            self.add_embedding = TextImageTimeEmbedding(
                text_embed_dim=cross_attention_dim,
                image_embed_dim=cross_attention_dim,
                time_embed_dim=time_embed_dim,
            )
        elif addition_embed_type == "text_time":
            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif addition_embed_type is not None:
            # The message lists every value accepted by the branches above.
            raise ValueError(
                f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image' or 'text_time'."
            )

        # control net conditioning embedding
        # NOTE: `forward` does not use this module; it is kept so the registered sub-modules (and therefore the
        # state-dict keys) stay the same.
        self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            block_out_channels=conditioning_embedding_out_channels,
            conditioning_channels=conditioning_channels,
        )

        # down: intentionally left empty, this model has no encoder path.
        self.down_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        # mid
        mid_block_channel = block_out_channels[-1]

        self.controlnet_mid_block = zero_module(nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1))

        self.mid_block = UNetMidBlock2DCrossAttn(
            transformer_layers_per_block=transformer_layers_per_block[-1],
            in_channels=mid_block_channel,
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads[-1],
            resnet_groups=norm_num_groups,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
        )

        # up
        self.controlnet_up_blocks = nn.ModuleList([])
        self.num_upsamplers = 0
        self.up_blocks = nn.ModuleList([])

        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
        only_cross_attention = list(reversed(only_cross_attention))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            # Passed to `get_up_block` for signature parity only; the skip-free blocks ignore it.
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            add_upsample = not is_final_block
            if add_upsample:
                self.num_upsamplers += 1

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=reversed_num_attention_heads[i],
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                attention_head_dim=(
                    attention_head_dim[i] if attention_head_dim[i] is not None else output_channel
                ),
            )
            self.up_blocks.append(up_block)

            # One zero-initialised projection per layer output of this up block (including the first, pure-conv
            # block). `forward` asserts that the two counts match.
            for _ in range(layers_per_block + 1):
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                self.controlnet_up_blocks.append(zero_module(controlnet_block))

    @classmethod
    def from_unet(
        cls,
        unet: UNet2DConditionModel,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
        load_weights_from_unet: bool = True,
    ):
        r"""
        Instantiate a [`UNetDec_ControlNetModel`] from a [`UNet2DConditionModel`].

        The returned model has `requires_grad_(False)` applied to all of its parameters.

        Parameters:
            unet (`UNet2DConditionModel`):
                The UNet model weights to copy to the new model. All configuration options are also copied where
                applicable.
            controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
                Forwarded to the constructor.
            conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
                Forwarded to the constructor.
            load_weights_from_unet (`bool`, defaults to `True`):
                Whether to copy the time projection/embedding, class embedding, mid block and up-block weights from
                `unet`. Up-block tensors whose shape differs from the UNet's (the skip-free blocks take fewer input
                channels) are initialised from the leading slice of the UNet tensor.

        Returns:
            The newly created model.
        """

        def _config_or(key, default):
            # Older UNet configs may not define every key.
            return unet.config[key] if key in unet.config else default

        extra_kwargs = {}
        if "up_block_types" in unet.config:
            # Mirror the decoder layout of the source UNet. Without this the constructor default (a 4-block layout)
            # is used, which fails or mismatches for UNets with a different decoder.
            extra_kwargs["up_block_types"] = unet.config.up_block_types

        controlnet = cls(
            encoder_hid_dim=_config_or("encoder_hid_dim", None),
            encoder_hid_dim_type=_config_or("encoder_hid_dim_type", None),
            addition_embed_type=_config_or("addition_embed_type", None),
            addition_time_embed_dim=_config_or("addition_time_embed_dim", None),
            transformer_layers_per_block=_config_or("transformer_layers_per_block", 1),
            in_channels=unet.config.in_channels,
            flip_sin_to_cos=unet.config.flip_sin_to_cos,
            freq_shift=unet.config.freq_shift,
            down_block_types=unet.config.down_block_types,
            only_cross_attention=unet.config.only_cross_attention,
            block_out_channels=unet.config.block_out_channels,
            layers_per_block=unet.config.layers_per_block,
            downsample_padding=unet.config.downsample_padding,
            mid_block_scale_factor=unet.config.mid_block_scale_factor,
            act_fn=unet.config.act_fn,
            norm_num_groups=unet.config.norm_num_groups,
            norm_eps=unet.config.norm_eps,
            cross_attention_dim=unet.config.cross_attention_dim,
            attention_head_dim=unet.config.attention_head_dim,
            num_attention_heads=unet.config.num_attention_heads,
            use_linear_projection=unet.config.use_linear_projection,
            class_embed_type=unet.config.class_embed_type,
            num_class_embeds=unet.config.num_class_embeds,
            upcast_attention=unet.config.upcast_attention,
            resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
            projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
            controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
            conditioning_embedding_out_channels=conditioning_embedding_out_channels,
            **extra_kwargs,
        ).requires_grad_(False)

        if not load_weights_from_unet:
            return controlnet

        # `conv_in` and the down blocks do not exist in this model, so there is nothing to load for them.
        controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
        controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())

        if controlnet.class_embedding is not None:
            controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())

        controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())

        # The up blocks cannot be loaded with `load_state_dict`: the first (fusion) layer of each skip-free block has
        # fewer input channels than its UNet counterpart, so those tensors are partially initialised instead.
        src_tensors = dict(named_params_and_buffers(unet.up_blocks))
        with torch.no_grad():
            for name, tensor in named_params_and_buffers(controlnet.up_blocks):
                assert name in src_tensors, name
                source = src_tensors[name].detach()
                try:
                    tensor.copy_(source)
                except RuntimeError:
                    # Shape mismatch: keep the leading channels of the UNet tensor.
                    # TODO: make sure every up-block parameter gets initialised.
                    if tensor.dim() == 1:
                        tensor.copy_(source[: tensor.shape[0]])
                    else:
                        tensor.copy_(source[:, : tensor.shape[1]])

        return controlnet

    def _set_gradient_checkpointing(self, module, value=False):
        """Set `gradient_checkpointing` on `module` if it exposes that attribute."""
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        conditioning_scale: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
        only_return_transformer_layers_out: bool = False,
    ) -> Union[ControlNetOutput, Tuple]:
        """
        The [`UNetDec_ControlNetModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The input tensor. It is fed directly to the mid block (there is no `conv_in` and no down path).
            timestep (`Union[torch.Tensor, float, int]`):
                The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.Tensor`):
                The encoder hidden states.
            conditioning_scale (`float`, defaults to `1.0`):
                The scale factor for ControlNet outputs.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
            added_cond_kwargs (`dict`):
                Additional conditions for the Stable Diffusion XL UNet. Must contain `text_embeds` and `time_ids`
                when `addition_embed_type` is `"text_time"`.
            cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
                A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
            guess_mode (`bool`, defaults to `False`):
                Not supported: raises `NotImplementedError` unless `global_pool_conditions` is enabled, in which case
                the flag has no effect.
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
            only_return_transformer_layers_out (`bool`, defaults to `False`):
                If `True`, `down_block_res_samples` omits the last layer output of every up block.

        Returns:
            [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
                If `return_dict` is `True`, a `ControlNetOutput` subclass that additionally carries
                `up_block_res_samples` is returned, otherwise the tuple
                `(down_block_res_samples, mid_block_res_sample, up_block_res_samples)`, where
                `down_block_res_samples` is the (optionally filtered) reverse of `up_block_res_samples`.
        """
        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        if self.config.addition_embed_type is not None:
            if self.config.addition_embed_type == "text":
                aug_emb = self.add_embedding(encoder_hidden_states)

            elif self.config.addition_embed_type == "text_time":
                # Treat a missing dict like an empty one so the caller gets the descriptive ValueError below
                # instead of a TypeError from `in None`.
                if added_cond_kwargs is None:
                    added_cond_kwargs = {}
                if "text_embeds" not in added_cond_kwargs:
                    raise ValueError(
                        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                    )
                text_embeds = added_cond_kwargs.get("text_embeds")
                if "time_ids" not in added_cond_kwargs:
                    raise ValueError(
                        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                    )
                time_ids = added_cond_kwargs.get("time_ids")
                time_embeds = self.add_time_proj(time_ids.flatten())
                time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

                add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
                add_embeds = add_embeds.to(emb.dtype)
                aug_emb = self.add_embedding(add_embeds)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.conv_in is not None:
            sample = self.conv_in(sample)

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
            )
        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 5. up (the blocks are skip-free, so no residual tuple is passed in)
        out_block_res_samples = []
        for upsample_block in self.up_blocks:
            if getattr(upsample_block, "has_cross_attention", False):
                sample, res_samples = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                )
            else:
                sample, res_samples = upsample_block(hidden_states=sample, temb=emb)

            out_block_res_samples += res_samples

        # 5. Control net blocks (zero-initialised 1x1 projections, one per layer output)
        assert len(out_block_res_samples) == len(self.controlnet_up_blocks), (
            len(out_block_res_samples),
            len(self.controlnet_up_blocks),
        )
        up_block_res_samples = tuple(
            controlnet_block(res_sample)
            for controlnet_block, res_sample in zip(self.controlnet_up_blocks, out_block_res_samples)
        )

        # 6. scaling
        if guess_mode and not self.config.global_pool_conditions:
            raise NotImplementedError("`guess_mode` is not implemented for this model.")

        up_block_res_samples = [res_sample * conditioning_scale for res_sample in up_block_res_samples]
        mid_block_res_sample = mid_block_res_sample * conditioning_scale

        if self.config.global_pool_conditions:
            up_block_res_samples = [
                torch.mean(res_sample, dim=(2, 3), keepdim=True) for res_sample in up_block_res_samples
            ]
            mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)

        if only_return_transformer_layers_out:
            # Every up block contributes `layers_per_block + 1` outputs; drop the last one of each group.
            group_size = self.config.layers_per_block + 1
            down_block_res_samples = [
                res_sample
                for i, res_sample in enumerate(up_block_res_samples)
                if i % group_size != group_size - 1
            ][::-1]
        else:
            down_block_res_samples = list(reversed(up_block_res_samples))

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample, up_block_res_samples)

        # The base `ControlNetOutput` declares no `up_block_res_samples` field, so the subclass is used here.
        return UNetDecControlNetOutput(
            down_block_res_samples=down_block_res_samples,
            mid_block_res_sample=mid_block_res_sample,
            up_block_res_samples=up_block_res_samples,
        )


def zero_module(module):
    """Zero every parameter of `module` in place and return the same module."""
    for p in module.parameters():
        nn.init.zeros_(p)
    return module


def named_params_and_buffers(module):
    """Return the `(name, tensor)` pairs of all parameters followed by all buffers of `module`."""
    assert isinstance(module, torch.nn.Module)
    return list(module.named_parameters()) + list(module.named_buffers())


from src.models.unet_2d_blocks import UpBlock2D, CrossAttnUpBlock2D
from diffusers.models.resnet import ResnetBlock2D
from diffusers.utils import is_torch_version


def _build_woskip_resnets(
    num_layers,
    prev_output_channel,
    out_channels,
    temb_channels,
    resnet_eps,
    resnet_groups,
    dropout,
    resnet_time_scale_shift,
    resnet_act_fn,
    output_scale_factor,
    resnet_pre_norm,
):
    """Build the resnet stack of a skip-free up block: only the first layer consumes `prev_output_channel`."""
    resnets = []
    for i in range(num_layers):
        resnet_in_channels = prev_output_channel if i == 0 else out_channels
        resnets.append(
            ResnetBlock2D(
                in_channels=resnet_in_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        )
    return nn.ModuleList(resnets)


def _checkpoint_forward(module):
    """Wrap `module` in a plain function so it can be handed to `torch.utils.checkpoint.checkpoint`."""

    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


class UpBlock2D_woskip(UpBlock2D):
    """
    `UpBlock2D` variant without skip connections.

    The parent constructor is called with `in_channels=0` and its resnets are then replaced by ones whose input is
    only the previous hidden state. `forward` returns the output of every resnet in addition to the final state.
    """

    def __init__(
        self,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super(UpBlock2D_woskip, self).__init__(
            in_channels=0,
            prev_output_channel=prev_output_channel,
            out_channels=out_channels,
            temb_channels=temb_channels,
            dropout=dropout,
            num_layers=num_layers,
            resnet_eps=resnet_eps,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_pre_norm=resnet_pre_norm,
            output_scale_factor=output_scale_factor,
            add_upsample=add_upsample,
        )
        # Replace the parent's resnets (sized for concatenated skip inputs) with skip-free ones.
        self.resnets = _build_woskip_resnets(
            num_layers=num_layers,
            prev_output_channel=prev_output_channel,
            out_channels=out_channels,
            temb_channels=temb_channels,
            resnet_eps=resnet_eps,
            resnet_groups=resnet_groups,
            dropout=dropout,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resnet_act_fn=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            resnet_pre_norm=resnet_pre_norm,
        )

    def forward(self, hidden_states, temb=None, upsample_size=None, scale: float = 1.0):
        """
        Run the resnets, then the optional upsamplers.

        Returns:
            `(hidden_states, output_states)` where `output_states` is the list of per-resnet outputs collected
            before upsampling.
        """
        output_states = []

        for resnet in self.resnets:
            if self.training and self.gradient_checkpointing:
                if is_torch_version(">=", "1.11.0"):
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        _checkpoint_forward(resnet), hidden_states, temb, use_reentrant=False
                    )
                else:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        _checkpoint_forward(resnet), hidden_states, temb
                    )
            else:
                hidden_states = resnet(hidden_states, temb, scale=scale)

            output_states.append(hidden_states)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size, scale=scale)

        return hidden_states, output_states


class CrossAttnUpBlock2D_woskip(CrossAttnUpBlock2D):
    """
    `CrossAttnUpBlock2D` variant without skip connections.

    The parent constructor is called with `in_channels=0` and its resnets are then replaced by ones whose input is
    only the previous hidden state. `forward` returns the output of every resnet/attention pair in addition to the
    final state.
    """

    def __init__(
        self,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        attention_type="default",
    ):
        super(CrossAttnUpBlock2D_woskip, self).__init__(
            in_channels=0,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            dropout=dropout,
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            resnet_eps=resnet_eps,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_pre_norm=resnet_pre_norm,
            num_attention_heads=num_attention_heads,
            cross_attention_dim=cross_attention_dim,
            output_scale_factor=output_scale_factor,
            add_upsample=add_upsample,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            attention_type=attention_type,
        )
        # Replace the parent's resnets (sized for concatenated skip inputs) with skip-free ones.
        self.resnets = _build_woskip_resnets(
            num_layers=num_layers,
            prev_output_channel=prev_output_channel,
            out_channels=out_channels,
            temb_channels=temb_channels,
            resnet_eps=resnet_eps,
            resnet_groups=resnet_groups,
            dropout=dropout,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resnet_act_fn=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            resnet_pre_norm=resnet_pre_norm,
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """
        Run each resnet/attention pair, then the optional upsamplers.

        Returns:
            `(hidden_states, output_states)` where `output_states` is the list of per-layer outputs (after the
            attention module) collected before upsampling.
        """
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
        use_checkpoint = self.training and self.gradient_checkpointing
        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        output_states = []

        for resnet, attn in zip(self.resnets, self.attentions):
            if use_checkpoint:
                # Only the resnet is checkpointed; the attention module runs normally in both paths.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    _checkpoint_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(hidden_states, temb, scale=lora_scale)

            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]

            output_states.append(hidden_states)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)

        return hidden_states, output_states


def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    transformer_layers_per_block=1,
    num_attention_heads=None,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    attention_type="default",
    resnet_skip_time_act=False,
    resnet_out_scale_factor=1.0,
    cross_attention_norm=None,
    attention_head_dim=None,
    upsample_type=None,
    dropout=0.0,
):
    """
    Build a skip-free up block by name.

    Supports `"UpBlock2D"` and `"CrossAttnUpBlock2D"` (optionally prefixed with `"UNetRes"`). `in_channels` and the
    remaining unused parameters are accepted for signature parity with the stock factory; the skip-free blocks do
    not take them.

    Raises:
        ValueError: If `up_block_type` is unknown, or if `cross_attention_dim` is `None` for a cross-attention block.
    """
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        logger.warning(
            f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
        )
        attention_head_dim = num_attention_heads

    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
    if up_block_type == "UpBlock2D":
        return UpBlock2D_woskip(
            num_layers=num_layers,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            dropout=dropout,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif up_block_type == "CrossAttnUpBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
        return CrossAttnUpBlock2D_woskip(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            dropout=dropout,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_type=attention_type,
        )

    raise ValueError(f"{up_block_type} does not exist.")