diff --git "a/controlnet/net.py" "b/controlnet/net.py" --- "a/controlnet/net.py" +++ "b/controlnet/net.py" @@ -1,1734 +1,1732 @@ - -# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This was modied from the control net repo - - -import inspect -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel - -import numpy as np -import torch -from transformers import ( - CLIPTextModel, - CLIPTokenizer, - T5EncoderModel, - T5TokenizerFast, -) - -from diffusers.image_processor import PipelineImageInput, VaeImageProcessor -from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin -from diffusers.models.autoencoders import AutoencoderKL -### MERGEING THESE ### -# from src.models.transformer import FluxTransformer2DModel -# from src.models.controlnet_flux import FluxControlNetModel -############# - -########################################## -########### ATTENTION MERGE ############## -########################################## - -import torch -from torch import Tensor, FloatTensor -from torch.nn import functional as F -from einops import rearrange -from diffusers.models.attention_processor import Attention -from diffusers.models.embeddings import apply_rotary_emb - - - -def fa3_sdpa( - q, - k, - v, -): - # flash attention 3 sdpa drop-in replacement - q, k, v = [x.permute(0, 2, 1, 3) for x in [q, k, v]] - out = flash_attn_func(q, k, v)[0] - return out.permute(0, 2, 1, 3) - - -class FluxSingleAttnProcessor3_0: - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - ) - - def __call__( - self, - attn, - hidden_states: Tensor, - encoder_hidden_states: Tensor = None, - attention_mask: FloatTensor = None, - image_rotary_emb: Tensor = None, - ) -> Tensor: - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view( - batch_size, channel, height * width - ).transpose(1, 2) - - batch_size, _, _ = ( - hidden_states.shape - if encoder_hidden_states is None - else encoder_hidden_states.shape - ) - - query = attn.to_q(hidden_states) - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - hidden_states = fa3_sdpa(query, key, value) - hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)") - - hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - hidden_states = hidden_states.to(query.dtype) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) - - return hidden_states - - -class FluxAttnProcessor3_0: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FluxAttnProcessor3_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - ) - - def __call__( - self, - attn, - hidden_states: FloatTensor, - encoder_hidden_states: FloatTensor = None, - attention_mask: FloatTensor = None, - image_rotary_emb: Tensor = None, - ) -> FloatTensor: - input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view( - batch_size, channel, height * width - ).transpose(1, 2) - context_input_ndim = encoder_hidden_states.ndim - if context_input_ndim == 4: - batch_size, channel, height, width = encoder_hidden_states.shape - encoder_hidden_states = encoder_hidden_states.view( - batch_size, channel, height * width - ).transpose(1, 2) - - batch_size = encoder_hidden_states.shape[0] - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # `context` projections. - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q( - encoder_hidden_states_query_proj - ) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k( - encoder_hidden_states_key_proj - ) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - hidden_states = fa3_sdpa(query, key, value) - hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)") - - hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - hidden_states = hidden_states.to(query.dtype) - - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) - if context_input_ndim == 4: - encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) - - return hidden_states, encoder_hidden_states - - - -class FluxFusedSDPAProcessor: - """ - Fused QKV processor using PyTorch's scaled_dot_product_attention. - Uses fused projections but splits for attention computation. - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FluxFusedSDPAProcessor requires PyTorch 2.0+ for scaled_dot_product_attention" - ) - - def __call__( - self, - attn, - hidden_states: FloatTensor, - encoder_hidden_states: FloatTensor = None, - attention_mask: FloatTensor = None, - image_rotary_emb: Tensor = None, - ) -> FloatTensor: - input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view( - batch_size, channel, height * width - ).transpose(1, 2) - - context_input_ndim = ( - encoder_hidden_states.ndim if encoder_hidden_states is not None else None - ) - if context_input_ndim == 4: - batch_size, channel, height, width = encoder_hidden_states.shape - encoder_hidden_states = encoder_hidden_states.view( - batch_size, channel, height * width - ).transpose(1, 2) - - batch_size = ( - encoder_hidden_states.shape[0] - if encoder_hidden_states is not None - else hidden_states.shape[0] - ) - - # Single attention case (no encoder states) - if encoder_hidden_states is None: - # Use fused QKV projection - qkv = attn.to_qkv(hidden_states) # (batch, seq_len, 3 * inner_dim) - inner_dim = qkv.shape[-1] // 3 - head_dim = inner_dim // attn.heads - seq_len = hidden_states.shape[1] - - # Split and reshape - qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim) - query, key, value = qkv.unbind( - dim=2 - ) # Each is (batch, seq_len, heads, head_dim) - - # Transpose to (batch, heads, seq_len, head_dim) - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - # Apply norms if needed - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - # SDPA - hidden_states = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=False, - ) - - # Reshape back - hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - hidden_states = hidden_states.to(query.dtype) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) - - return hidden_states - - # Joint attention case (with encoder states) - else: - # Process self-attention QKV - qkv = attn.to_qkv(hidden_states) - inner_dim = qkv.shape[-1] // 3 - head_dim = inner_dim // attn.heads - seq_len = hidden_states.shape[1] - - qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim) - query, key, value = qkv.unbind(dim=2) - - # Transpose to (batch, heads, seq_len, head_dim) - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - # Apply norms if needed - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Process encoder QKV - encoder_seq_len = encoder_hidden_states.shape[1] - encoder_qkv = attn.to_added_qkv(encoder_hidden_states) - encoder_qkv = encoder_qkv.view( - batch_size, encoder_seq_len, 3, attn.heads, head_dim - ) - encoder_query, encoder_key, encoder_value = encoder_qkv.unbind(dim=2) - - # Transpose to (batch, heads, seq_len, head_dim) - encoder_query = encoder_query.transpose(1, 2) - encoder_key = encoder_key.transpose(1, 2) - encoder_value = encoder_value.transpose(1, 2) - - # Apply encoder norms if needed - if attn.norm_added_q is not None: - encoder_query = attn.norm_added_q(encoder_query) - if attn.norm_added_k is not None: - encoder_key = attn.norm_added_k(encoder_key) - - # Concatenate encoder and self-attention - query = torch.cat([encoder_query, query], dim=2) - key = torch.cat([encoder_key, key], dim=2) - value = torch.cat([encoder_value, value], dim=2) - - # Apply RoPE if needed - if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - # SDPA - hidden_states = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=False, - ) - - # Reshape: (batch, heads, seq_len, head_dim) -> (batch, seq_len, heads * head_dim) - hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - hidden_states = hidden_states.to(query.dtype) - - # Split encoder and self outputs - encoder_hidden_states = hidden_states[:, :encoder_seq_len] - hidden_states = hidden_states[:, encoder_seq_len:] - - # Output projections - hidden_states = attn.to_out[0](hidden_states) - hidden_states = attn.to_out[1](hidden_states) # dropout - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - # Reshape if needed - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) - if context_input_ndim == 4: - encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) - - return hidden_states, encoder_hidden_states - - -class FluxSingleFusedSDPAProcessor: - """ - Fused QKV processor for single attention (no encoder states). - Simpler version for self-attention only blocks. - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FluxSingleFusedSDPAProcessor requires PyTorch 2.0+ for scaled_dot_product_attention" - ) - - def __call__( - self, - attn, - hidden_states: Tensor, - encoder_hidden_states: Tensor = None, - attention_mask: FloatTensor = None, - image_rotary_emb: Tensor = None, - ) -> Tensor: - input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view( - batch_size, channel, height * width - ).transpose(1, 2) - - batch_size, seq_len, _ = hidden_states.shape - - # Use fused QKV projection - qkv = attn.to_qkv(hidden_states) # (batch, seq_len, 3 * inner_dim) - inner_dim = qkv.shape[-1] // 3 - head_dim = inner_dim // attn.heads - - # Split and reshape in one go - qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim) - qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, H, L, D) – still strided - query, key, value = [ - t.contiguous() for t in qkv.unbind(0) # make each view dense - ] - # Now each is (batch, heads, seq_len, head_dim) - - # Apply norms if needed - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - # SDPA - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - # Reshape back - hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)") - hidden_states = hidden_states.to(query.dtype) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) - - return hidden_states - -################################# -##### TRANSFORMER MERGE ######### -################################# - -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin -from diffusers.models.attention import FeedForward -from diffusers.models.attention_processor import ( - Attention, - AttentionProcessor, -) -from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.normalization import ( - AdaLayerNormContinuous, - AdaLayerNormZero, - AdaLayerNormZeroSingle, -) -from diffusers.utils import ( - USE_PEFT_BACKEND, - is_torch_version, - logging, - scale_lora_layers, - unscale_lora_layers, -) -from diffusers.utils.torch_utils import maybe_allow_in_graph -from diffusers.models.embeddings import ( - CombinedTimestepGuidanceTextProjEmbeddings, - CombinedTimestepTextProjEmbeddings, - FluxPosEmbed, -) - -from diffusers.models.modeling_outputs import Transformer2DModelOutput -from diffusers import FluxTransformer2DModel as OriginalFluxTransformer2DModel - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -is_flash_attn_available = False - - - -class FluxAttnProcessor2_0: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = ( - hidden_states.shape - if encoder_hidden_states is None - else encoder_hidden_states.shape - ) - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - if encoder_hidden_states is not None: - # `context` projections. - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q( - encoder_hidden_states_query_proj - ) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k( - encoder_hidden_states_key_proj - ) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from diffusers.models.embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - if attention_mask is not None: - #print ('Attention Used') - attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - attention_mask = (attention_mask > 0).bool() - # Edit 17 - match attn dtype to query d-type - attention_mask = attention_mask.to( - device=hidden_states.device, dtype=query.dtype - ) - - hidden_states = F.scaled_dot_product_attention( - query, - key, - value, - dropout_p=0.0, - is_causal=False, - attn_mask=attention_mask, - ) - hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - return hidden_states - - -def expand_flux_attention_mask( - hidden_states: torch.Tensor, - attn_mask: torch.Tensor, -) -> torch.Tensor: - """ - Expand a mask so that the image is included. - """ - bsz = attn_mask.shape[0] - assert bsz == hidden_states.shape[0] - residual_seq_len = hidden_states.shape[1] - mask_seq_len = attn_mask.shape[1] - - expanded_mask = torch.ones(bsz, residual_seq_len) - expanded_mask[:, :mask_seq_len] = attn_mask - - return expanded_mask - - -@maybe_allow_in_graph -class FluxSingleTransformerBlock(nn.Module): - r""" - A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. - - Reference: https://arxiv.org/abs/2403.03206 - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the - processing of `context` conditions. - """ - - def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): - super().__init__() - self.mlp_hidden_dim = int(dim * mlp_ratio) - - self.norm = AdaLayerNormZeroSingle(dim) - self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) - self.act_mlp = nn.GELU(approximate="tanh") - self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) - - processor = FluxAttnProcessor2_0() - self.attn = Attention( - query_dim=dim, - cross_attention_dim=None, - dim_head=attention_head_dim, - heads=num_attention_heads, - out_dim=dim, - bias=True, - processor=processor, - qk_norm="rms_norm", - eps=1e-6, - pre_only=True, - ) - - def forward( - self, - hidden_states: torch.FloatTensor, - temb: torch.FloatTensor, - image_rotary_emb=None, - attention_mask: Optional[torch.Tensor] = None, - ): - residual = hidden_states - norm_hidden_states, gate = self.norm(hidden_states, emb=temb) - mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) - - if attention_mask is not None: - attention_mask = expand_flux_attention_mask( - hidden_states, - attention_mask, - ) - - attn_output = self.attn( - hidden_states=norm_hidden_states, - image_rotary_emb=image_rotary_emb, - attention_mask=attention_mask, - ) - - hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) - gate = gate.unsqueeze(1) - hidden_states = gate * self.proj_out(hidden_states) - hidden_states = residual + hidden_states - - if hidden_states.dtype == torch.float16: - hidden_states = hidden_states.clip(-65504, 65504) - - return hidden_states - - -@maybe_allow_in_graph -class FluxTransformerBlock(nn.Module): - r""" - A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. - - Reference: https://arxiv.org/abs/2403.03206 - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the - processing of `context` conditions. - """ - - def __init__( - self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6 - ): - super().__init__() - - self.norm1 = AdaLayerNormZero(dim) - - self.norm1_context = AdaLayerNormZero(dim) - - if hasattr(F, "scaled_dot_product_attention"): - processor = FluxAttnProcessor2_0() - else: - raise ValueError( - "The current PyTorch version does not support the `scaled_dot_product_attention` function." - ) - self.attn = Attention( - query_dim=dim, - cross_attention_dim=None, - added_kv_proj_dim=dim, - dim_head=attention_head_dim, - heads=num_attention_heads, - out_dim=dim, - context_pre_only=False, - bias=True, - processor=processor, - qk_norm=qk_norm, - eps=eps, - ) - - self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") - - self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff_context = FeedForward( - dim=dim, dim_out=dim, activation_fn="gelu-approximate" - ) - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - - def forward( - self, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor, - temb: torch.FloatTensor, - image_rotary_emb=None, - attention_mask: Optional[torch.Tensor] = None, - ): - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, emb=temb - ) - - norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = ( - self.norm1_context(encoder_hidden_states, emb=temb) - ) - - if attention_mask is not None: - attention_mask = expand_flux_attention_mask( - torch.cat([encoder_hidden_states, hidden_states], dim=1), - attention_mask, - ) - - # Attention. - attention_outputs = self.attn( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_encoder_hidden_states, - image_rotary_emb=image_rotary_emb, - attention_mask=attention_mask, - ) - if len(attention_outputs) == 2: - attn_output, context_attn_output = attention_outputs - elif len(attention_outputs) == 3: - attn_output, context_attn_output, ip_attn_output = attention_outputs - - # Process attention outputs for the `hidden_states`. - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = hidden_states + attn_output - - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = ( - norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - ) - - ff_output = self.ff(norm_hidden_states) - ff_output = gate_mlp.unsqueeze(1) * ff_output - - hidden_states = hidden_states + ff_output - if len(attention_outputs) == 3: - hidden_states = hidden_states + ip_attn_output - - # Process attention outputs for the `encoder_hidden_states`. - context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output - encoder_hidden_states = encoder_hidden_states + context_attn_output - - norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) - norm_encoder_hidden_states = ( - norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) - + c_shift_mlp[:, None] - ) - - context_ff_output = self.ff_context(norm_encoder_hidden_states) - encoder_hidden_states = ( - encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output - ) - - if encoder_hidden_states.dtype == torch.float16: - encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) - - return encoder_hidden_states, hidden_states - - -class LibreFluxTransformer2DModel( - ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin -): - """ - The Transformer model introduced in Flux. - - Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ - - Parameters: - patch_size (`int`): Patch size to turn the input data into small patches. - in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. - num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. - num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. - attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. - num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. - joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. - guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. - """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - patch_size: int = 1, - in_channels: int = 64, - num_layers: int = 19, - num_single_layers: int = 38, - attention_head_dim: int = 128, - num_attention_heads: int = 24, - joint_attention_dim: int = 4096, - pooled_projection_dim: int = 768, - guidance_embeds: bool = False, - axes_dims_rope: Tuple[int] = (16, 56, 56), - ): - super().__init__() - self.out_channels = in_channels - self.inner_dim = ( - self.config.num_attention_heads * self.config.attention_head_dim - ) - - self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) - text_time_guidance_cls = ( - CombinedTimestepGuidanceTextProjEmbeddings ### 3 input forward (timestep, guidance, pooled_projection) - if guidance_embeds - else CombinedTimestepTextProjEmbeddings #### 2 input forward (timestep, pooled_projection) - ) - self.time_text_embed = text_time_guidance_cls( - embedding_dim=self.inner_dim, - pooled_projection_dim=self.config.pooled_projection_dim, - ) - - self.context_embedder = nn.Linear( - self.config.joint_attention_dim, self.inner_dim - ) - self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) - - self.transformer_blocks = nn.ModuleList( - [ - FluxTransformerBlock( - dim=self.inner_dim, - num_attention_heads=self.config.num_attention_heads, - attention_head_dim=self.config.attention_head_dim, - ) - for i in range(self.config.num_layers) - ] - ) - - self.single_transformer_blocks = nn.ModuleList( - [ - FluxSingleTransformerBlock( - dim=self.inner_dim, - num_attention_heads=self.config.num_attention_heads, - attention_head_dim=self.config.attention_head_dim, - ) - for i in range(self.config.num_single_layers) - ] - ) - - self.norm_out = AdaLayerNormContinuous( - self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6 - ) - self.proj_out = nn.Linear( - self.inner_dim, patch_size * patch_size * self.out_channels, bias=True - ) - - self.gradient_checkpointing = False - # added for users to disable checkpointing every nth step - self.gradient_checkpointing_interval = None - - def set_gradient_checkpointing_interval(self, value: int): - self.gradient_checkpointing_interval = value - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors( - name: str, - module: torch.nn.Module, - processors: Dict[str, AttentionProcessor], - ): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] - ): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor = None, - pooled_projections: torch.Tensor = None, - timestep: torch.LongTensor = None, - img_ids: torch.Tensor = None, - txt_ids: torch.Tensor = None, - guidance: torch.Tensor = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_block_samples=None, - controlnet_single_block_samples=None, - return_dict: bool = True, - attention_mask: Optional[torch.Tensor] = None, - controlnet_blocks_repeat: bool = False, - ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: - """ - The [`FluxTransformer2DModel`] forward method. - - Args: - hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): - Input `hidden_states`. - encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): - Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected - from the embeddings of input conditions. - timestep ( `torch.LongTensor`): - Used to indicate denoising step. - block_controlnet_hidden_states: (`list` of `torch.Tensor`): - A list of tensors that if specified are added to the residuals of transformer blocks. - joint_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain - tuple. - - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. - """ - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if ( - joint_attention_kwargs is not None - and joint_attention_kwargs.get("scale", None) is not None - ): - logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." - ) - hidden_states = self.x_embedder(hidden_states) - - timestep = timestep.to(hidden_states.dtype) * 1000 - if guidance is not None: - guidance = guidance.to(hidden_states.dtype) * 1000 - else: - guidance = None - - #print( self.time_text_embed) - temb = ( - self.time_text_embed(timestep,pooled_projections) - # Edit 1 # Charlie NOT NEEDED - UNDONE - if guidance is None - else self.time_text_embed(timestep, guidance, pooled_projections) - ) - encoder_hidden_states = self.context_embedder(encoder_hidden_states) - - if txt_ids.ndim == 3: - txt_ids = txt_ids[0] - if img_ids.ndim == 3: - img_ids = img_ids[0] - - ids = torch.cat((txt_ids, img_ids), dim=0) - - image_rotary_emb = self.pos_embed(ids) - - # IP adapter - if ( - joint_attention_kwargs is not None - and "ip_adapter_image_embeds" in joint_attention_kwargs - ): - ip_adapter_image_embeds = joint_attention_kwargs.pop( - "ip_adapter_image_embeds" - ) - ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) - joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) - - for index_block, block in enumerate(self.transformer_blocks): - if ( - self.training - and self.gradient_checkpointing - and ( - self.gradient_checkpointing_interval is None - or index_block % self.gradient_checkpointing_interval == 0 - ) - ): - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = ( - {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - ) - encoder_hidden_states, hidden_states = ( - torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - temb, - image_rotary_emb, - attention_mask, - **ckpt_kwargs, - ) - ) - - else: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - attention_mask=attention_mask, - ) - - # controlnet residual - if controlnet_block_samples is not None: - interval_control = len(self.transformer_blocks) / len( - controlnet_block_samples - ) - interval_control = int(np.ceil(interval_control)) - # For Xlabs ControlNet. - if controlnet_blocks_repeat: - hidden_states = ( - hidden_states - + controlnet_block_samples[ - index_block % len(controlnet_block_samples) - ] - ) - else: - hidden_states = ( - hidden_states - + controlnet_block_samples[index_block // interval_control] - ) - - # Flux places the text tokens in front of the image tokens in the - # sequence. - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - - for index_block, block in enumerate(self.single_transformer_blocks): - if ( - self.training - and self.gradient_checkpointing - or ( - self.gradient_checkpointing_interval is not None - and index_block % self.gradient_checkpointing_interval == 0 - ) - ): - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = ( - {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - temb, - image_rotary_emb, - attention_mask, - **ckpt_kwargs, - ) - - else: - hidden_states = block( - hidden_states=hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - attention_mask=attention_mask, - ) - - # controlnet residual - if controlnet_single_block_samples is not None: - interval_control = len(self.single_transformer_blocks) / len( - controlnet_single_block_samples - ) - interval_control = int(np.ceil(interval_control)) - hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( - hidden_states[:, encoder_hidden_states.shape[1] :, ...] - + controlnet_single_block_samples[index_block // interval_control] - ) - - hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] - - hidden_states = self.norm_out(hidden_states, temb) - output = self.proj_out(hidden_states) - - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) - -################################### -# END TRANS MERGE -#################################### - -# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This was modied from the control net repo - - -#################################### -##### CONTROL NET MODEL MERGE ###### -#################################### - - -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn as nn - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.loaders import PeftAdapterMixin -from diffusers.models.attention_processor import AttentionProcessor -from diffusers.models.modeling_utils import ModelMixin -from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version, logging, scale_lora_layers, unscale_lora_layers -from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding, zero_module -from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed -from diffusers.models.modeling_outputs import Transformer2DModelOutput - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -class FluxControlNetOutput(BaseOutput): - controlnet_block_samples: Tuple[torch.Tensor] - controlnet_single_block_samples: Tuple[torch.Tensor] - - -class LibreFluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - patch_size: int = 1, - in_channels: int = 64, - num_layers: int = 19, - num_single_layers: int = 38, - attention_head_dim: int = 128, - num_attention_heads: int = 24, - joint_attention_dim: int = 4096, - pooled_projection_dim: int = 768, - guidance_embeds: bool = False, - axes_dims_rope: List[int] = [16, 56, 56], - num_mode: int = None, - conditioning_embedding_channels: int = None, - ): - super().__init__() - self.out_channels = in_channels - self.inner_dim = num_attention_heads * attention_head_dim - - self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) - - # edit 19 - #text_time_guidance_cls = ( - # CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings - #) - - text_time_guidance_cls = CombinedTimestepGuidanceTextProjEmbeddings - text_time_cls = CombinedTimestepTextProjEmbeddings - - self.time_text_embed = text_time_cls( - embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim - ) - self.time_text_guidance_embed = text_time_guidance_cls( - embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim - ) - - self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) - self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim) - - self.transformer_blocks = nn.ModuleList( - [ - FluxTransformerBlock( - dim=self.inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - ) - for i in range(num_layers) - ] - ) - - self.single_transformer_blocks = nn.ModuleList( - [ - FluxSingleTransformerBlock( - dim=self.inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - ) - for i in range(num_single_layers) - ] - ) - - # controlnet_blocks - self.controlnet_blocks = nn.ModuleList([]) - for _ in range(len(self.transformer_blocks)): - self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) - - self.controlnet_single_blocks = nn.ModuleList([]) - for _ in range(len(self.single_transformer_blocks)): - self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) - - self.union = num_mode is not None - if self.union: - self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim) - - if conditioning_embedding_channels is not None: - self.input_hint_block = ControlNetConditioningEmbedding( - conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16) - ) - self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim) - else: - self.input_hint_block = None - self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim)) - - self.gradient_checkpointing = False - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self): - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - - @classmethod - def from_transformer( - cls, - transformer, - num_layers: int = 4, - num_single_layers: int = 10, - attention_head_dim: int = 128, - num_attention_heads: int = 24, - load_weights_from_transformer=True, - ): - config = dict(transformer.config) - config["num_layers"] = num_layers - config["num_single_layers"] = num_single_layers - config["attention_head_dim"] = attention_head_dim - config["num_attention_heads"] = num_attention_heads - - controlnet = cls.from_config(config) - - if load_weights_from_transformer: - controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) - controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict()) - controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict()) - controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict()) - controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False) - controlnet.single_transformer_blocks.load_state_dict( - transformer.single_transformer_blocks.state_dict(), strict=False - ) - - controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder) - - return controlnet - - # Edit 13 Adding attention masking to forward - def forward( - self, - hidden_states: torch.Tensor, - controlnet_cond: torch.Tensor, - controlnet_mode: torch.Tensor = None, - conditioning_scale: float = 1.0, - encoder_hidden_states: torch.Tensor = None, - pooled_projections: torch.Tensor = None, - timestep: torch.LongTensor = None, - img_ids: torch.Tensor = None, - txt_ids: torch.Tensor = None, - guidance: torch.Tensor = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - attention_mask: Optional[torch.Tensor] = None, # <-- 1. ADD ARGUMENT HERE - - ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: - """ - The [`FluxTransformer2DModel`] forward method. - - Args: - hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): - Input `hidden_states`. - controlnet_cond (`torch.Tensor`): - The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. - controlnet_mode (`torch.Tensor`): - The mode tensor of shape `(batch_size, 1)`. - conditioning_scale (`float`, defaults to `1.0`): - The scale factor for ControlNet outputs. - encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): - Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected - from the embeddings of input conditions. - timestep ( `torch.LongTensor`): - Used to indicate denoising step. - block_controlnet_hidden_states: (`list` of `torch.Tensor`): - A list of tensors that if specified are added to the residuals of transformer blocks. - joint_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain - tuple. - - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. - """ - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." - ) - hidden_states = self.x_embedder(hidden_states) - - if self.input_hint_block is not None: - controlnet_cond = self.input_hint_block(controlnet_cond) - batch_size, channels, height_pw, width_pw = controlnet_cond.shape - height = height_pw // self.config.patch_size - width = width_pw // self.config.patch_size - controlnet_cond = controlnet_cond.reshape( - batch_size, channels, height, self.config.patch_size, width, self.config.patch_size - ) - controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5) - controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1) - # add - hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond) - - timestep = timestep.to(hidden_states.dtype) * 1000 - if guidance is not None: - guidance = guidance.to(hidden_states.dtype) * 1000 - else: - guidance = None - - #print ('Guidance:', guidance) - temb = ( - self.time_text_embed(timestep, pooled_projections) - if guidance is None - # edit 19 - else self.time_text_guidance_embed(timestep, guidance, pooled_projections) - ) - encoder_hidden_states = self.context_embedder(encoder_hidden_states) - - if self.union: - # union mode - if controlnet_mode is None: - raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union") - # union mode emb - controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode) - encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1) - txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0) - - if txt_ids.ndim == 3: - logger.warning( - "Passing `txt_ids` 3d torch.Tensor is deprecated." - "Please remove the batch dimension and pass it as a 2d torch Tensor" - ) - txt_ids = txt_ids[0] - if img_ids.ndim == 3: - logger.warning( - "Passing `img_ids` 3d torch.Tensor is deprecated." - "Please remove the batch dimension and pass it as a 2d torch Tensor" - ) - img_ids = img_ids[0] - - ids = torch.cat((txt_ids, img_ids), dim=0) - image_rotary_emb = self.pos_embed(ids) - - block_samples = () - for index_block, block in enumerate(self.transformer_blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - temb, - image_rotary_emb, - attention_mask, # Edit 13 - **ckpt_kwargs, - ) - - else: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - attention_mask=attention_mask, # Edit 13 - - ) - block_samples = block_samples + (hidden_states,) - - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - - single_block_samples = () - for index_block, block in enumerate(self.single_transformer_blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - temb, - image_rotary_emb, - attention_mask, # <-- 2. PASS MASK TO GRADIENT CHECKPOINTING # Edit 13 - **ckpt_kwargs, - ) - - else: - hidden_states = block( - hidden_states=hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - attention_mask=attention_mask, # <-- 2. PASS MASK TO BLOCK Edit 13 - - ) - single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) - - # controlnet block - controlnet_block_samples = () - for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): - block_sample = controlnet_block(block_sample) - controlnet_block_samples = controlnet_block_samples + (block_sample,) - - controlnet_single_block_samples = () - for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks): - single_block_sample = controlnet_block(single_block_sample) - controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,) - - # scaling - controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples] - controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples] - - controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples - controlnet_single_block_samples = ( - None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples - ) - - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - - if not return_dict: - return (controlnet_block_samples, controlnet_single_block_samples) - - return FluxControlNetOutput( - controlnet_block_samples=controlnet_block_samples, - controlnet_single_block_samples=controlnet_single_block_samples, - ) - + +# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This was modied from the control net repo + + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import ( + CLIPTextModel, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from diffusers.models.autoencoders import AutoencoderKL +### MERGEING THESE ### +# from src.models.transformer import FluxTransformer2DModel +# from src.models.controlnet_flux import FluxControlNetModel +############# + +########################################## +########### ATTENTION MERGE ############## +########################################## + +import torch +from torch import Tensor, FloatTensor +from torch.nn import functional as F +from einops import rearrange +from diffusers.models.attention_processor import Attention +from diffusers.models.embeddings import apply_rotary_emb + + + +def fa3_sdpa( + q, + k, + v, +): + # flash attention 3 sdpa drop-in replacement + q, k, v = [x.permute(0, 2, 1, 3) for x in [q, k, v]] + out = flash_attn_func(q, k, v)[0] + return out.permute(0, 2, 1, 3) + + +class FluxSingleAttnProcessor3_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn, + hidden_states: Tensor, + encoder_hidden_states: Tensor = None, + attention_mask: FloatTensor = None, + image_rotary_emb: Tensor = None, + ) -> Tensor: + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, _, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + query = attn.to_q(hidden_states) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = fa3_sdpa(query, key, value) + hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)") + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + return hidden_states + + +class FluxAttnProcessor3_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FluxAttnProcessor3_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn, + hidden_states: FloatTensor, + encoder_hidden_states: FloatTensor = None, + attention_mask: FloatTensor = None, + image_rotary_emb: Tensor = None, + ) -> FloatTensor: + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + context_input_ndim = encoder_hidden_states.ndim + if context_input_ndim == 4: + batch_size, channel, height, width = encoder_hidden_states.shape + encoder_hidden_states = encoder_hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size = encoder_hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q( + encoder_hidden_states_query_proj + ) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k( + encoder_hidden_states_key_proj + ) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = fa3_sdpa(query, key, value) + hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)") + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + if context_input_ndim == 4: + encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + return hidden_states, encoder_hidden_states + + + +class FluxFusedSDPAProcessor: + """ + Fused QKV processor using PyTorch's scaled_dot_product_attention. + Uses fused projections but splits for attention computation. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FluxFusedSDPAProcessor requires PyTorch 2.0+ for scaled_dot_product_attention" + ) + + def __call__( + self, + attn, + hidden_states: FloatTensor, + encoder_hidden_states: FloatTensor = None, + attention_mask: FloatTensor = None, + image_rotary_emb: Tensor = None, + ) -> FloatTensor: + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + context_input_ndim = ( + encoder_hidden_states.ndim if encoder_hidden_states is not None else None + ) + if context_input_ndim == 4: + batch_size, channel, height, width = encoder_hidden_states.shape + encoder_hidden_states = encoder_hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size = ( + encoder_hidden_states.shape[0] + if encoder_hidden_states is not None + else hidden_states.shape[0] + ) + + # Single attention case (no encoder states) + if encoder_hidden_states is None: + # Use fused QKV projection + qkv = attn.to_qkv(hidden_states) # (batch, seq_len, 3 * inner_dim) + inner_dim = qkv.shape[-1] // 3 + head_dim = inner_dim // attn.heads + seq_len = hidden_states.shape[1] + + # Split and reshape + qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim) + query, key, value = qkv.unbind( + dim=2 + ) # Each is (batch, seq_len, heads, head_dim) + + # Transpose to (batch, heads, seq_len, head_dim) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # Apply norms if needed + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + # SDPA + hidden_states = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + ) + + # Reshape back + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + return hidden_states + + # Joint attention case (with encoder states) + else: + # Process self-attention QKV + qkv = attn.to_qkv(hidden_states) + inner_dim = qkv.shape[-1] // 3 + head_dim = inner_dim // attn.heads + seq_len = hidden_states.shape[1] + + qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim) + query, key, value = qkv.unbind(dim=2) + + # Transpose to (batch, heads, seq_len, head_dim) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # Apply norms if needed + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Process encoder QKV + encoder_seq_len = encoder_hidden_states.shape[1] + encoder_qkv = attn.to_added_qkv(encoder_hidden_states) + encoder_qkv = encoder_qkv.view( + batch_size, encoder_seq_len, 3, attn.heads, head_dim + ) + encoder_query, encoder_key, encoder_value = encoder_qkv.unbind(dim=2) + + # Transpose to (batch, heads, seq_len, head_dim) + encoder_query = encoder_query.transpose(1, 2) + encoder_key = encoder_key.transpose(1, 2) + encoder_value = encoder_value.transpose(1, 2) + + # Apply encoder norms if needed + if attn.norm_added_q is not None: + encoder_query = attn.norm_added_q(encoder_query) + if attn.norm_added_k is not None: + encoder_key = attn.norm_added_k(encoder_key) + + # Concatenate encoder and self-attention + query = torch.cat([encoder_query, query], dim=2) + key = torch.cat([encoder_key, key], dim=2) + value = torch.cat([encoder_value, value], dim=2) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + # SDPA + hidden_states = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + ) + + # Reshape: (batch, heads, seq_len, head_dim) -> (batch, seq_len, heads * head_dim) + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + # Split encoder and self outputs + encoder_hidden_states = hidden_states[:, :encoder_seq_len] + hidden_states = hidden_states[:, encoder_seq_len:] + + # Output projections + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) # dropout + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + # Reshape if needed + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + if context_input_ndim == 4: + encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + return hidden_states, encoder_hidden_states + + +class FluxSingleFusedSDPAProcessor: + """ + Fused QKV processor for single attention (no encoder states). + Simpler version for self-attention only blocks. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FluxSingleFusedSDPAProcessor requires PyTorch 2.0+ for scaled_dot_product_attention" + ) + + def __call__( + self, + attn, + hidden_states: Tensor, + encoder_hidden_states: Tensor = None, + attention_mask: FloatTensor = None, + image_rotary_emb: Tensor = None, + ) -> Tensor: + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, seq_len, _ = hidden_states.shape + + # Use fused QKV projection + qkv = attn.to_qkv(hidden_states) # (batch, seq_len, 3 * inner_dim) + inner_dim = qkv.shape[-1] // 3 + head_dim = inner_dim // attn.heads + + # Split and reshape in one go + qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim) + qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, H, L, D) – still strided + query, key, value = [ + t.contiguous() for t in qkv.unbind(0) # make each view dense + ] + # Now each is (batch, heads, seq_len, head_dim) + + # Apply norms if needed + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + # SDPA + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + # Reshape back + hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)") + hidden_states = hidden_states.to(query.dtype) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + return hidden_states + +################################# +##### TRANSFORMER MERGE ######### +################################# + +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.models.attention import FeedForward +from diffusers.models.attention_processor import ( + Attention, + AttentionProcessor, +) +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import ( + AdaLayerNormContinuous, + AdaLayerNormZero, + AdaLayerNormZeroSingle, +) +from diffusers.utils import ( + USE_PEFT_BACKEND, + is_torch_version, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.embeddings import ( + CombinedTimestepGuidanceTextProjEmbeddings, + CombinedTimestepTextProjEmbeddings, + FluxPosEmbed, +) + +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers import FluxTransformer2DModel as OriginalFluxTransformer2DModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +is_flash_attn_available = False + + + +class FluxAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q( + encoder_hidden_states_query_proj + ) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k( + encoder_hidden_states_key_proj + ) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from diffusers.models.embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + if attention_mask is not None: + #print ('Attention Used') + attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = (attention_mask > 0).bool() + # Edit 17 - match attn dtype to query d-type + attention_mask = attention_mask.to( + device=hidden_states.device, dtype=query.dtype + ) + + hidden_states = F.scaled_dot_product_attention( + query, + key, + value, + dropout_p=0.0, + is_causal=False, + attn_mask=attention_mask, + ) + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + return hidden_states + + +def expand_flux_attention_mask( + hidden_states: torch.Tensor, + attn_mask: torch.Tensor, +) -> torch.Tensor: + """ + Expand a mask so that the image is included. + """ + bsz = attn_mask.shape[0] + assert bsz == hidden_states.shape[0] + residual_seq_len = hidden_states.shape[1] + mask_seq_len = attn_mask.shape[1] + + expanded_mask = torch.ones(bsz, residual_seq_len) + expanded_mask[:, :mask_seq_len] = attn_mask + + return expanded_mask + + +@maybe_allow_in_graph +class FluxSingleTransformerBlock(nn.Module): + r""" + A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. + + Reference: https://arxiv.org/abs/2403.03206 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the + processing of `context` conditions. + """ + + def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm = AdaLayerNormZeroSingle(dim) + self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) + + processor = FluxAttnProcessor2_0() + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + processor=processor, + qk_norm="rms_norm", + eps=1e-6, + pre_only=True, + ) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: torch.FloatTensor, + image_rotary_emb=None, + attention_mask: Optional[torch.Tensor] = None, + ): + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + + if attention_mask is not None: + attention_mask = expand_flux_attention_mask( + hidden_states, + attention_mask, + ) + + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + ) + + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + return hidden_states + + +@maybe_allow_in_graph +class FluxTransformerBlock(nn.Module): + r""" + A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. + + Reference: https://arxiv.org/abs/2403.03206 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the + processing of `context` conditions. + """ + + def __init__( + self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6 + ): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim) + + self.norm1_context = AdaLayerNormZero(dim) + + if hasattr(F, "scaled_dot_product_attention"): + processor = FluxAttnProcessor2_0() + else: + raise ValueError( + "The current PyTorch version does not support the `scaled_dot_product_attention` function." + ) + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=processor, + qk_norm=qk_norm, + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward( + dim=dim, dim_out=dim, activation_fn="gelu-approximate" + ) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + temb: torch.FloatTensor, + image_rotary_emb=None, + attention_mask: Optional[torch.Tensor] = None, + ): + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, emb=temb + ) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = ( + self.norm1_context(encoder_hidden_states, emb=temb) + ) + + if attention_mask is not None: + attention_mask = expand_flux_attention_mask( + torch.cat([encoder_hidden_states, hidden_states], dim=1), + attention_mask, + ) + + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + ) + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = ( + norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + + c_shift_mlp[:, None] + ) + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = ( + encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + ) + + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class LibreFluxTransformer2DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin +): + """ + The Transformer model introduced in Flux. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Parameters: + patch_size (`int`): Patch size to turn the input data into small patches. + in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. + num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. + joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. + guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 64, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 4096, + pooled_projection_dim: int = 768, + guidance_embeds: bool = False, + axes_dims_rope: Tuple[int] = (16, 56, 56), + ): + super().__init__() + self.out_channels = in_channels + self.inner_dim = ( + self.config.num_attention_heads * self.config.attention_head_dim + ) + + self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) + text_time_guidance_cls = ( + CombinedTimestepGuidanceTextProjEmbeddings ### 3 input forward (timestep, guidance, pooled_projection) + if guidance_embeds + else CombinedTimestepTextProjEmbeddings #### 2 input forward (timestep, pooled_projection) + ) + self.time_text_embed = text_time_guidance_cls( + embedding_dim=self.inner_dim, + pooled_projection_dim=self.config.pooled_projection_dim, + ) + + self.context_embedder = nn.Linear( + self.config.joint_attention_dim, self.inner_dim + ) + self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for i in range(self.config.num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + FluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for i in range(self.config.num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6 + ) + self.proj_out = nn.Linear( + self.inner_dim, patch_size * patch_size * self.out_channels, bias=True + ) + + self.gradient_checkpointing = False + # added for users to disable checkpointing every nth step + self.gradient_checkpointing_interval = None + + def set_gradient_checkpointing_interval(self, value: int): + self.gradient_checkpointing_interval = value + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + return_dict: bool = True, + attention_mask: Optional[torch.Tensor] = None, + controlnet_blocks_repeat: bool = False, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if ( + joint_attention_kwargs is not None + and joint_attention_kwargs.get("scale", None) is not None + ): + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + else: + guidance = None + + #print( self.time_text_embed) + temb = ( + self.time_text_embed(timestep,pooled_projections) + # Edit 1 # Charlie NOT NEEDED - UNDONE + if guidance is None + else self.time_text_embed(timestep, guidance, pooled_projections) + ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + + image_rotary_emb = self.pos_embed(ids) + + # IP adapter + if ( + joint_attention_kwargs is not None + and "ip_adapter_image_embeds" in joint_attention_kwargs + ): + ip_adapter_image_embeds = joint_attention_kwargs.pop( + "ip_adapter_image_embeds" + ) + ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) + joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) + + for index_block, block in enumerate(self.transformer_blocks): + if ( + self.training + and self.gradient_checkpointing + and ( + self.gradient_checkpointing_interval is None + or index_block % self.gradient_checkpointing_interval == 0 + ) + ): + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + encoder_hidden_states, hidden_states = ( + torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + attention_mask, + **ckpt_kwargs, + ) + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + ) + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len( + controlnet_block_samples + ) + interval_control = int(np.ceil(interval_control)) + # For Xlabs ControlNet. + if controlnet_blocks_repeat: + hidden_states = ( + hidden_states + + controlnet_block_samples[ + index_block % len(controlnet_block_samples) + ] + ) + else: + hidden_states = ( + hidden_states + + controlnet_block_samples[index_block // interval_control] + ) + + # Flux places the text tokens in front of the image tokens in the + # sequence. + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + for index_block, block in enumerate(self.single_transformer_blocks): + if ( + self.training + and self.gradient_checkpointing + or ( + self.gradient_checkpointing_interval is not None + and index_block % self.gradient_checkpointing_interval == 0 + ) + ): + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + temb, + image_rotary_emb, + attention_mask, + **ckpt_kwargs, + ) + + else: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + ) + + # controlnet residual + if controlnet_single_block_samples is not None: + interval_control = len(self.single_transformer_blocks) / len( + controlnet_single_block_samples + ) + interval_control = int(np.ceil(interval_control)) + hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( + hidden_states[:, encoder_hidden_states.shape[1] :, ...] + + controlnet_single_block_samples[index_block // interval_control] + ) + + hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + +################################### +# END TRANS MERGE +#################################### + +# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This was modied from the control net repo + + +#################################### +##### CONTROL NET MODEL MERGE ###### +#################################### + + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin +from diffusers.models.attention_processor import AttentionProcessor +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding, zero_module +from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed +from diffusers.models.modeling_outputs import Transformer2DModelOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class FluxControlNetOutput(BaseOutput): + controlnet_block_samples: Tuple[torch.Tensor] + controlnet_single_block_samples: Tuple[torch.Tensor] + + +class LibreFluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 64, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 4096, + pooled_projection_dim: int = 768, + guidance_embeds: bool = False, + axes_dims_rope: List[int] = [16, 56, 56], + num_mode: int = None, + conditioning_embedding_channels: int = None, + ): + super().__init__() + self.out_channels = in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) + + # edit 19 + #text_time_guidance_cls = ( + # CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings + #) + + text_time_guidance_cls = CombinedTimestepGuidanceTextProjEmbeddings + text_time_cls = CombinedTimestepTextProjEmbeddings + + self.time_text_embed = text_time_cls( + embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim + ) + self.time_text_guidance_embed = text_time_guidance_cls( + embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim + ) + + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) + self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for i in range(num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + FluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for i in range(num_single_layers) + ] + ) + + # controlnet_blocks + self.controlnet_blocks = nn.ModuleList([]) + for _ in range(len(self.transformer_blocks)): + self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) + + self.controlnet_single_blocks = nn.ModuleList([]) + for _ in range(len(self.single_transformer_blocks)): + self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) + + self.union = num_mode is not None + if self.union: + self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim) + + if conditioning_embedding_channels is not None: + self.input_hint_block = ControlNetConditioningEmbedding( + conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16) + ) + self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim) + else: + self.input_hint_block = None + self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim)) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self): + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @classmethod + def from_transformer( + cls, + transformer, + num_layers: int = 4, + num_single_layers: int = 10, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + load_weights_from_transformer=True, + ): + config = dict(transformer.config) + config["num_layers"] = num_layers + config["num_single_layers"] = num_single_layers + config["attention_head_dim"] = attention_head_dim + config["num_attention_heads"] = num_attention_heads + + controlnet = cls.from_config(config) + + if load_weights_from_transformer: + controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) + controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict()) + controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict()) + controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict()) + controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False) + controlnet.single_transformer_blocks.load_state_dict( + transformer.single_transformer_blocks.state_dict(), strict=False + ) + + controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder) + + return controlnet + + # Edit 13 Adding attention masking to forward + def forward( + self, + hidden_states: torch.Tensor, + controlnet_cond: torch.Tensor, + controlnet_mode: torch.Tensor = None, + conditioning_scale: float = 1.0, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + attention_mask: Optional[torch.Tensor] = None, # <-- 1. ADD ARGUMENT HERE + + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + controlnet_cond (`torch.Tensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + controlnet_mode (`torch.Tensor`): + The mode tensor of shape `(batch_size, 1)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + hidden_states = self.x_embedder(hidden_states) + + if self.input_hint_block is not None: + controlnet_cond = self.input_hint_block(controlnet_cond) + batch_size, channels, height_pw, width_pw = controlnet_cond.shape + height = height_pw // self.config.patch_size + width = width_pw // self.config.patch_size + controlnet_cond = controlnet_cond.reshape( + batch_size, channels, height, self.config.patch_size, width, self.config.patch_size + ) + controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5) + controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1) + # add + hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond) + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + else: + guidance = None + + #print ('Guidance:', guidance) + temb = ( + self.time_text_embed(timestep, pooled_projections) + if guidance is None + # edit 19 + else self.time_text_guidance_embed(timestep, guidance, pooled_projections) + ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if self.union: + # union mode + if controlnet_mode is None: + raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union") + # union mode emb + controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode) + encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1) + txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0) + + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + block_samples = () + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + attention_mask, # Edit 13 + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, # Edit 13 + + ) + block_samples = block_samples + (hidden_states,) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + single_block_samples = () + for index_block, block in enumerate(self.single_transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + temb, + image_rotary_emb, + attention_mask, # <-- 2. PASS MASK TO GRADIENT CHECKPOINTING # Edit 13 + **ckpt_kwargs, + ) + + else: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, # <-- 2. PASS MASK TO BLOCK Edit 13 + + ) + single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) + + # controlnet block + controlnet_block_samples = () + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): + block_sample = controlnet_block(block_sample) + controlnet_block_samples = controlnet_block_samples + (block_sample,) + + controlnet_single_block_samples = () + for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks): + single_block_sample = controlnet_block(single_block_sample) + controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,) + + # scaling + controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples] + controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples] + + controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples + controlnet_single_block_samples = ( + None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples + ) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (controlnet_block_samples, controlnet_single_block_samples) + + return FluxControlNetOutput( + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, + ) +