Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from typing import Any, Dict, Optional, Tuple, Union | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin | |
| from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers | |
| from diffusers.utils.import_utils import is_torch_npu_available | |
| from diffusers.utils.torch_utils import maybe_allow_in_graph | |
| from diffusers.models.attention import FeedForward | |
| from diffusers.models.attention_processor import ( | |
| Attention, | |
| AttentionProcessor, | |
| FluxAttnProcessor2_0, | |
| FluxAttnProcessor2_0_NPU, | |
| FusedFluxAttnProcessor2_0, | |
| ) | |
| from diffusers.models.cache_utils import CacheMixin | |
| from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed | |
| from diffusers.models.modeling_outputs import Transformer2DModelOutput | |
| from diffusers.models.modeling_utils import ModelMixin | |
| from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle | |
| logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
class FluxSingleTransformerBlock(nn.Module):
    """Single-stream Flux block: attention and an MLP branch computed in parallel on one token sequence.

    The input goes through `AdaLayerNormZeroSingle`, which also returns a gate. The normalised tokens feed
    both the attention layer and a `proj_mlp` -> GELU branch. The two results are concatenated along dim 2,
    projected back to `dim` by `proj_out`, multiplied by the gate and added to the input.

    Args:
        dim (`int`): Channel width of the token sequence.
        num_attention_heads (`int`): Number of attention heads.
        attention_head_dim (`int`): Number of channels per attention head.
        mlp_ratio (`float`, defaults to `4.0`): Width of the MLP branch as a multiple of `dim`.
    """

    # NOTE(review): `maybe_allow_in_graph` is imported by this file but not applied to this class -- confirm
    # whether a class decorator was lost when the file was extracted.

    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.mlp_hidden_dim = int(dim * mlp_ratio)

        # Sub-modules are created in a fixed order so parameter initialisation order stays stable.
        self.norm = AdaLayerNormZeroSingle(dim)
        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
        self.act_mlp = nn.GELU(approximate="tanh")
        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

        if not is_torch_npu_available():
            attn_processor = FluxAttnProcessor2_0()
        else:
            npu_notice = (
                "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
                "should be set explicitly using the `set_attn_processor` method."
            )
            deprecate("npu_processor", "0.34.0", npu_notice)
            attn_processor = FluxAttnProcessor2_0_NPU()

        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            processor=attn_processor,
            qk_norm="rms_norm",
            eps=1e-6,
            pre_only=True,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """Apply the block to `hidden_states`.

        Args:
            hidden_states: Token sequence; the attention and MLP outputs are concatenated along dim 2.
            temb: Conditioning embedding handed to the adaptive norm as `emb`.
            image_rotary_emb: Optional rotary embedding forwarded to the attention layer.
            joint_attention_kwargs: Optional extra keyword arguments forwarded to the attention layer.

        Returns:
            The input plus the gated block output. float16 results are clipped to [-65504, 65504].
        """
        extra_attn_kwargs = joint_attention_kwargs if joint_attention_kwargs else {}

        normed_tokens, gate = self.norm(hidden_states, emb=temb)
        mlp_branch = self.act_mlp(self.proj_mlp(normed_tokens))
        attn_branch = self.attn(
            hidden_states=normed_tokens,
            image_rotary_emb=image_rotary_emb,
            **extra_attn_kwargs,
        )

        # Fuse both branches along dim 2 and project back to `dim` before gating.
        fused = torch.cat([attn_branch, mlp_branch], dim=2)
        gated = gate.unsqueeze(1) * self.proj_out(fused)
        output = hidden_states + gated

        if output.dtype != torch.float16:
            return output
        # Keep float16 activations inside the finite half-precision range.
        return output.clip(-65504, 65504)
class FluxTransformerBlock(nn.Module):
    """Dual-stream Flux block with optional modulation ("slider") conditioning on the context stream.

    Image tokens (`hidden_states`) and context tokens (`encoder_hidden_states`) are each normalised by their
    own `AdaLayerNormZero`, attend jointly through one `Attention` layer, and then pass through separate
    feed-forward branches. When `modulation_condn` is supplied, it offsets the scale and shift applied to the
    context stream before its feed-forward branch; the image stream is not modified by it.

    Args:
        dim (`int`): Channel width of both token streams.
        num_attention_heads (`int`): Number of attention heads.
        attention_head_dim (`int`): Number of channels per attention head.
        qk_norm (`str`, defaults to `"rms_norm"`): Query/key normalisation passed to `Attention`.
        eps (`float`, defaults to `1e-6`): Epsilon passed to `Attention`.
    """

    # NOTE(review): `maybe_allow_in_graph` is imported by this file but not applied to this class -- confirm
    # whether a class decorator was lost when the file was extracted.

    def __init__(
        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
    ):
        super().__init__()

        self.norm1 = AdaLayerNormZero(dim)
        self.norm1_context = AdaLayerNormZero(dim)

        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            processor=FluxAttnProcessor2_0(),
            qk_norm=qk_norm,
            eps=eps,
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        # Modulation conditioning vector used to control the strength of the edit.
        modulation_condn: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run joint attention and the per-stream feed-forward branches.

        Args:
            hidden_states: Image-stream tokens.
            encoder_hidden_states: Context-stream tokens.
            temb: Conditioning embedding handed to both adaptive norms as `emb`.
            image_rotary_emb: Optional rotary embedding forwarded to the attention layer.
            joint_attention_kwargs: Optional extra keyword arguments forwarded to the attention layer.
            modulation_condn: Optional conditioning tensor. Dim 1 is squeezed, then the tensor is split in two
                along dim 1: the first half is added to the context-stream MLP scale and the second half to
                the context-stream MLP shift. Presumably shaped `(batch, 1, 2 * dim)` -- confirm against the
                caller. When `None`, the standard modulation is used unchanged.

        Returns:
            `(encoder_hidden_states, hidden_states)` -- note the context stream comes first.

        Raises:
            ValueError: If the attention layer returns neither 2 nor 3 outputs.
        """
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )

        # With modulation conditioning, add a delta to the context-stream scale/shift; otherwise keep the
        # regular modulation values produced by `norm1_context`.
        if modulation_condn is not None:
            modulation_condn = modulation_condn.squeeze(1)
            # First half is the scale delta, second half is the shift delta.
            modulation_scale, modulation_shift = modulation_condn.chunk(2, dim=1)
            c_shift_mlp = c_shift_mlp + modulation_shift
            c_scale_mlp = c_scale_mlp + modulation_scale

        joint_attention_kwargs = joint_attention_kwargs or {}

        # Attention.
        attention_outputs = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        num_attention_outputs = len(attention_outputs)
        if num_attention_outputs == 2:
            attn_output, context_attn_output = attention_outputs
        elif num_attention_outputs == 3:
            # The third output is an additional image-stream residual (e.g. from an IP-adapter processor).
            attn_output, context_attn_output, ip_attn_output = attention_outputs
        else:
            # Fail here with a clear message instead of an UnboundLocalError further down.
            raise ValueError(
                f"Expected 2 or 3 outputs from the attention processor, but got {num_attention_outputs}."
            )

        # Process attention outputs for the `hidden_states`.
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = hidden_states + ff_output
        if num_attention_outputs == 3:
            hidden_states = hidden_states + ip_attn_output

        # Process attention outputs for the `encoder_hidden_states`.
        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
        encoder_hidden_states = encoder_hidden_states + context_attn_output

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        # `c_scale_mlp` / `c_shift_mlp` already include the modulation deltas when conditioning was given.
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
        if encoder_hidden_states.dtype == torch.float16:
            # Keep float16 activations inside the finite half-precision range.
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states
class FluxTransformer2DModelwithSliderConditioning(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
):
    """
    The Transformer model introduced in Flux, extended with optional modulation ("slider") conditioning.

    An optional `modulation_embeddings` tensor given to `forward` is handed to every dual-stream block as
    `modulation_condn`. Single-stream blocks do not receive it.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    Args:
        patch_size (`int`, defaults to `1`):
            Patch size to turn the input data into small patches.
        in_channels (`int`, defaults to `64`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of channels in the output. If not specified, it defaults to `in_channels`.
        num_layers (`int`, defaults to `19`):
            The number of layers of dual stream DiT blocks to use.
        num_single_layers (`int`, defaults to `38`):
            The number of layers of single stream DiT blocks to use.
        attention_head_dim (`int`, defaults to `128`):
            The number of dimensions to use for each attention head.
        num_attention_heads (`int`, defaults to `24`):
            The number of attention heads to use.
        joint_attention_dim (`int`, defaults to `4096`):
            The number of dimensions to use for the joint attention (embedding/channel dimension of
            `encoder_hidden_states`).
        pooled_projection_dim (`int`, defaults to `768`):
            The number of dimensions to use for the pooled projection.
        guidance_embeds (`bool`, defaults to `False`):
            Whether to use guidance embeddings for guidance-distilled variant of the model.
        axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
            The dimensions to use for the rotary positional embeddings.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
    _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]

    # Record the constructor arguments in `self.config` so the model can be saved and re-created from config.
    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        out_channels: Optional[int] = None,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)

        text_time_guidance_cls = (
            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
        )
        self.time_text_embed = text_time_guidance_cls(
            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )

        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
        self.x_embedder = nn.Linear(in_channels, self.inner_dim)

        # Dual-stream blocks: these are the blocks that consume the modulation conditioning.
        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_single_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    # `set_attn_processor` and `fuse_qkv_projections` read this as an attribute
    # (`self.attn_processors.keys()` / `.items()`), so it must be a property.
    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        Raises:
            ValueError: If a dict is passed whose size differs from the number of attention layers.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # Entries are consumed (popped) from the dict as they are assigned.
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        Raises:
            ValueError: If any attention processor class name contains `"Added"`.
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        # Remember the current processors so `unfuse_qkv_projections` can restore them.
        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedFluxAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        Does nothing when `fuse_qkv_projections` has not stored any original processors.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        # `original_attn_processors` only exists after `fuse_qkv_projections` has run; treat "never fused"
        # as a no-op instead of raising AttributeError.
        original_attn_processors = getattr(self, "original_attn_processors", None)
        if original_attn_processors is not None:
            self.set_attn_processor(original_attn_processors)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        return_dict: bool = True,
        controlnet_blocks_repeat: bool = False,
        # Modulation conditioning: an embedding used to modulate the features of the diffusion model.
        modulation_embeddings: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        The [`FluxTransformer2DModelwithSliderConditioning`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                from the embeddings of input conditions.
            timestep ( `torch.LongTensor`):
                Used to indicate denoising step. It is multiplied by 1000 before being embedded.
            img_ids (`torch.Tensor`):
                Position ids of the image tokens. A 3D tensor is deprecated; only its first entry is used.
            txt_ids (`torch.Tensor`):
                Position ids of the text tokens. A 3D tensor is deprecated; only its first entry is used.
            guidance (`torch.Tensor`, *optional*):
                Guidance value; when given it is multiplied by 1000 and embedded together with `timestep`.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
                A `"scale"` entry is removed and used as the LoRA scale. The caller's dict is not mutated.
            controlnet_block_samples (`list` of `torch.Tensor`, *optional*):
                A list of tensors that if specified are added to the outputs of the dual-stream blocks.
            controlnet_single_block_samples (`list` of `torch.Tensor`, *optional*):
                A list of tensors that if specified are added to the image part of the single-stream block outputs.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.
            controlnet_blocks_repeat (`bool`, defaults to `False`):
                If `True`, `controlnet_block_samples` is indexed cyclically (`index % len`) instead of by interval.
            modulation_embeddings (`torch.Tensor`, *optional*):
                Modulation conditioning passed to every dual-stream block as `modulation_condn`, with and without
                gradient checkpointing. Presumably shaped `(batch_size, 1, 2 * inner_dim)` -- confirm against the
                caller.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        lora_scale = 1.0
        scale_was_passed = False
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            # Check for `scale` before popping it, otherwise the non-PEFT warning below can never trigger.
            scale_was_passed = joint_attention_kwargs.get("scale", None) is not None
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        elif scale_was_passed:
            logger.warning(
                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
            )

        hidden_states = self.x_embedder(hidden_states)

        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000

        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )
        # `temb` is left untouched here: the modulation conditioning is applied inside each dual-stream block.
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` 3d torch.Tensor is deprecated."
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` 3d torch.Tensor is deprecated."
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            img_ids = img_ids[0]

        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)

        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
            # NOTE(review): `encoder_hid_proj` is not created in `__init__`; presumably one of the loader
            # mixins attaches it -- confirm.
            ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
            joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

        # Dual-stream blocks: text and image tokens are kept as separate sequences.
        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Arguments are positional and follow `FluxTransformerBlock.forward`. The attention kwargs and
                # the modulation conditioning must be forwarded here as well, otherwise checkpointed training
                # would silently run without the slider conditioning.
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    joint_attention_kwargs,
                    modulation_embeddings,
                )
            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                    # the modulation conditioning vector that controls the editing strength
                    modulation_condn=modulation_embeddings,
                )

            # controlnet residual
            if controlnet_block_samples is not None:
                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                interval_control = int(np.ceil(interval_control))
                # For Xlabs ControlNet.
                if controlnet_blocks_repeat:
                    hidden_states = (
                        hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
                    )
                else:
                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

        # Context tokens first, image tokens after: the slicing below relies on this order.
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        # Single-stream blocks: text and image tokens are processed as one concatenated sequence.
        for index_block, block in enumerate(self.single_transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Positional arguments follow `FluxSingleTransformerBlock.forward`.
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    temb,
                    image_rotary_emb,
                    joint_attention_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

            # controlnet residual (added to the image tokens only)
            if controlnet_single_block_samples is not None:
                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
                interval_control = int(np.ceil(interval_control))
                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                    + controlnet_single_block_samples[index_block // interval_control]
                )

        # Drop the context tokens and keep only the image tokens.
        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)