diff --git "a/backup_pipeline.py" "b/backup_pipeline.py" deleted file mode 100644--- "a/backup_pipeline.py" +++ /dev/null @@ -1,2827 +0,0 @@ -# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This was modied from the control net repo - - -import inspect -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel - -import numpy as np -import torch -from transformers import ( - CLIPTextModel, - CLIPTokenizer, - T5EncoderModel, - T5TokenizerFast, -) - -from diffusers.image_processor import PipelineImageInput, VaeImageProcessor -from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin -from diffusers.models.autoencoders import AutoencoderKL -### MERGEING THESE ### -# from src.models.transformer import FluxTransformer2DModel -# from src.models.controlnet_flux import FluxControlNetModel -############# - -########################################## -########### ATTENTION MERGE ############## -########################################## - -import torch -from torch import Tensor, FloatTensor -from torch.nn import functional as F -from einops import rearrange -from diffusers.models.attention_processor import Attention -from diffusers.models.embeddings import apply_rotary_emb - -#try: -# from flash_attn_interface import flash_attn_func, flash_attn_qkvpacked_func -#except: -# pass - - -"""def fa3_sdpa( - q, - k, - v, -): - # flash attention 3 sdpa drop-in replacement - q, k, v = [x.permute(0, 2, 1, 3) for x in [q, k, v]] - out = flash_attn_func(q, k, v)[0] - return out.permute(0, 2, 1, 3)""" - -""" -class FluxSingleAttnProcessor3_0: - r"" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - "" - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
- ) - - def __call__( - self, - attn, - hidden_states: Tensor, - encoder_hidden_states: Tensor = None, - attention_mask: FloatTensor = None, - image_rotary_emb: Tensor = None, - ) -> Tensor: - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view( - batch_size, channel, height * width - ).transpose(1, 2) - - batch_size, _, _ = ( - hidden_states.shape - if encoder_hidden_states is None - else encoder_hidden_states.shape - ) - - query = attn.to_q(hidden_states) - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - hidden_states = fa3_sdpa(query, key, value) - hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)") - - hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - hidden_states = hidden_states.to(query.dtype) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) - - return hidden_states - - -class FluxAttnProcessor3_0: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FluxAttnProcessor3_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - ) - - def __call__( - self, - attn, - hidden_states: FloatTensor, - encoder_hidden_states: FloatTensor = None, - attention_mask: FloatTensor = None, - image_rotary_emb: Tensor = None, - ) -> FloatTensor: - input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view( - batch_size, channel, height * width - ).transpose(1, 2) - context_input_ndim = encoder_hidden_states.ndim - if context_input_ndim == 4: - batch_size, channel, height, width = encoder_hidden_states.shape - encoder_hidden_states = encoder_hidden_states.view( - batch_size, channel, height * width - ).transpose(1, 2) - - batch_size = encoder_hidden_states.shape[0] - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # `context` projections. 
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q( - encoder_hidden_states_query_proj - ) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k( - encoder_hidden_states_key_proj - ) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - hidden_states = fa3_sdpa(query, key, value) - hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)") - - hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - hidden_states = hidden_states.to(query.dtype) - - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) - if context_input_ndim == 4: - encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) - - return hidden_states, encoder_hidden_states - - -class FluxFusedFlashAttnProcessor3(object): - """ - True fused QKV Flash Attention 3 processor for Flux models. - Keeps QKV tensors packed through the entire attention computation. - """ - - def __init__(self): - self.flash_attn_qkvpacked_func = None - try: - from flash_attn_interface import flash_attn_qkvpacked_func - - self.flash_attn_qkvpacked_func = flash_attn_qkvpacked_func - except ImportError: - raise ImportError( - "FluxFusedFlashAttnProcessor3 requires flash-attn library. 
" - "Please see this link for Hopper and Blackwell instructions: https://github.com/bghira/SimpleTuner/blob/main/INSTALL.md#nvidia-hopper--blackwell-follow-up-steps" - ) - - def __call__( - self, - attn, - hidden_states: FloatTensor, - encoder_hidden_states: FloatTensor = None, - attention_mask: FloatTensor = None, - image_rotary_emb: Tensor = None, - ) -> FloatTensor: - input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view( - batch_size, channel, height * width - ).transpose(1, 2) - - context_input_ndim = ( - encoder_hidden_states.ndim if encoder_hidden_states is not None else None - ) - if context_input_ndim == 4: - batch_size, channel, height, width = encoder_hidden_states.shape - encoder_hidden_states = encoder_hidden_states.view( - batch_size, channel, height * width - ).transpose(1, 2) - - batch_size = ( - encoder_hidden_states.shape[0] - if encoder_hidden_states is not None - else hidden_states.shape[0] - ) - seq_len = hidden_states.shape[1] - - # Fused QKV projection - qkv = attn.to_qkv(hidden_states) # (batch, seq_len, 3 * inner_dim) - inner_dim = qkv.shape[-1] // 3 - head_dim = inner_dim // attn.heads - - # Reshape to packed format: (batch, seq_len, 3, heads, head_dim) - qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim) - - # Apply norms if needed (requires temporary unpacking) - if attn.norm_q is not None or attn.norm_k is not None: - q, k, v = qkv.unbind(dim=2) # Each is (batch, seq_len, heads, head_dim) - q = q.transpose(1, 2) # (batch, heads, seq_len, head_dim) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - if attn.norm_q is not None: - q = attn.norm_q(q) - if attn.norm_k is not None: - k = attn.norm_k(k) - - # Repack: back to (batch, seq_len, 3, heads, head_dim) - qkv = torch.stack( - [q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)], dim=2 - ) - - # Handle encoder states if present - if encoder_hidden_states is not None: - encoder_seq_len = encoder_hidden_states.shape[1] - - # Fused encoder QKV - encoder_qkv = attn.to_added_qkv(encoder_hidden_states) - encoder_qkv = encoder_qkv.view( - batch_size, encoder_seq_len, 3, attn.heads, head_dim - ) - - # Apply norms if needed - if attn.norm_added_q is not None or attn.norm_added_k is not None: - enc_q, enc_k, enc_v = encoder_qkv.unbind(dim=2) - enc_q = enc_q.transpose(1, 2) - enc_k = enc_k.transpose(1, 2) - enc_v = enc_v.transpose(1, 2) - - if attn.norm_added_q is not None: - enc_q = attn.norm_added_q(enc_q) - if attn.norm_added_k is not None: - enc_k = attn.norm_added_k(enc_k) - - encoder_qkv = torch.stack( - [ - enc_q.transpose(1, 2), - enc_k.transpose(1, 2), - enc_v.transpose(1, 2), - ], - dim=2, - ) - - # Concatenate along sequence dimension - qkv = torch.cat( - [encoder_qkv, qkv], dim=1 - ) # (batch, encoder_seq + seq, 3, heads, head_dim) - - # Apply RoPE if needed - if image_rotary_emb is not None: - q, k, v = qkv.unbind(dim=2) # Each is (batch, seq_len, heads, head_dim) - - # Transpose to (batch, heads, seq_len, head_dim) for RoPE - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - # Apply RoPE to q and k - q = apply_rotary_emb(q, image_rotary_emb) - k = apply_rotary_emb(k, image_rotary_emb) - - # Transpose back and repack - qkv = torch.stack( - [q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)], dim=2 - ) - - # Flash Attention 3 with packed QKV - # Input shape: (batch, seq_len, 3, heads, head_dim) - # Output shape: (batch, seq_len, heads, head_dim) - hidden_states = 
self.flash_attn_qkvpacked_func( - qkv, - causal=False, - # Don't pass num_heads_q for standard MHA - ) - - # Reshape output: (batch, seq_len, heads, head_dim) -> (batch, seq_len, heads * head_dim) - hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(qkv.dtype) - - # Split and process outputs - if encoder_hidden_states is not None: - encoder_seq_len = encoder_hidden_states.shape[1] - encoder_hidden_states = hidden_states[:, :encoder_seq_len] - hidden_states = hidden_states[:, encoder_seq_len:] - - # Output projections - hidden_states = attn.to_out[0](hidden_states) - hidden_states = attn.to_out[1](hidden_states) # dropout - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - # Reshape if needed - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) - if context_input_ndim == 4: - encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) - - return hidden_states, encoder_hidden_states - else: - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) - return hidden_states -""" - -class FluxFusedSDPAProcessor: - """ - Fused QKV processor using PyTorch's scaled_dot_product_attention. - Uses fused projections but splits for attention computation. - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FluxFusedSDPAProcessor requires PyTorch 2.0+ for scaled_dot_product_attention" - ) - - def __call__( - self, - attn, - hidden_states: FloatTensor, - encoder_hidden_states: FloatTensor = None, - attention_mask: FloatTensor = None, - image_rotary_emb: Tensor = None, - ) -> FloatTensor: - input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view( - batch_size, channel, height * width - ).transpose(1, 2) - - context_input_ndim = ( - encoder_hidden_states.ndim if encoder_hidden_states is not None else None - ) - if context_input_ndim == 4: - batch_size, channel, height, width = encoder_hidden_states.shape - encoder_hidden_states = encoder_hidden_states.view( - batch_size, channel, height * width - ).transpose(1, 2) - - batch_size = ( - encoder_hidden_states.shape[0] - if encoder_hidden_states is not None - else hidden_states.shape[0] - ) - - # Single attention case (no encoder states) - if encoder_hidden_states is None: - # Use fused QKV projection - qkv = attn.to_qkv(hidden_states) # (batch, seq_len, 3 * inner_dim) - inner_dim = qkv.shape[-1] // 3 - head_dim = inner_dim // attn.heads - seq_len = hidden_states.shape[1] - - # Split and reshape - qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim) - query, key, value = qkv.unbind( - dim=2 - ) # Each is (batch, seq_len, heads, head_dim) - - # Transpose to (batch, heads, seq_len, head_dim) - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - # Apply norms if needed - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - # SDPA - hidden_states = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=False, - ) - - # Reshape back - 
hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - hidden_states = hidden_states.to(query.dtype) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) - - return hidden_states - - # Joint attention case (with encoder states) - else: - # Process self-attention QKV - qkv = attn.to_qkv(hidden_states) - inner_dim = qkv.shape[-1] // 3 - head_dim = inner_dim // attn.heads - seq_len = hidden_states.shape[1] - - qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim) - query, key, value = qkv.unbind(dim=2) - - # Transpose to (batch, heads, seq_len, head_dim) - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - # Apply norms if needed - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Process encoder QKV - encoder_seq_len = encoder_hidden_states.shape[1] - encoder_qkv = attn.to_added_qkv(encoder_hidden_states) - encoder_qkv = encoder_qkv.view( - batch_size, encoder_seq_len, 3, attn.heads, head_dim - ) - encoder_query, encoder_key, encoder_value = encoder_qkv.unbind(dim=2) - - # Transpose to (batch, heads, seq_len, head_dim) - encoder_query = encoder_query.transpose(1, 2) - encoder_key = encoder_key.transpose(1, 2) - encoder_value = encoder_value.transpose(1, 2) - - # Apply encoder norms if needed - if attn.norm_added_q is not None: - encoder_query = attn.norm_added_q(encoder_query) - if attn.norm_added_k is not None: - encoder_key = attn.norm_added_k(encoder_key) - - # Concatenate encoder and self-attention - query = torch.cat([encoder_query, query], dim=2) - key = torch.cat([encoder_key, key], dim=2) - value = torch.cat([encoder_value, value], dim=2) - - # Apply RoPE if needed - if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - # SDPA - hidden_states = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=False, - ) - - # Reshape: (batch, heads, seq_len, head_dim) -> (batch, seq_len, heads * head_dim) - hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - hidden_states = hidden_states.to(query.dtype) - - # Split encoder and self outputs - encoder_hidden_states = hidden_states[:, :encoder_seq_len] - hidden_states = hidden_states[:, encoder_seq_len:] - - # Output projections - hidden_states = attn.to_out[0](hidden_states) - hidden_states = attn.to_out[1](hidden_states) # dropout - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - # Reshape if needed - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) - if context_input_ndim == 4: - encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) - - return hidden_states, encoder_hidden_states - - -class FluxSingleFusedSDPAProcessor: - """ - Fused QKV processor for single attention (no encoder states). - Simpler version for self-attention only blocks. 
- """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FluxSingleFusedSDPAProcessor requires PyTorch 2.0+ for scaled_dot_product_attention" - ) - - def __call__( - self, - attn, - hidden_states: Tensor, - encoder_hidden_states: Tensor = None, - attention_mask: FloatTensor = None, - image_rotary_emb: Tensor = None, - ) -> Tensor: - input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view( - batch_size, channel, height * width - ).transpose(1, 2) - - batch_size, seq_len, _ = hidden_states.shape - - # Use fused QKV projection - qkv = attn.to_qkv(hidden_states) # (batch, seq_len, 3 * inner_dim) - inner_dim = qkv.shape[-1] // 3 - head_dim = inner_dim // attn.heads - - # Split and reshape in one go - qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim) - qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, H, L, D) – still strided - query, key, value = [ - t.contiguous() for t in qkv.unbind(0) # make each view dense - ] - # Now each is (batch, heads, seq_len, head_dim) - - # Apply norms if needed - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - # SDPA - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - # Reshape back - hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)") - hidden_states = hidden_states.to(query.dtype) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) - - return hidden_states - -################################# -##### TRANSFORMER MERGE ######### -################################# - -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin -from diffusers.models.attention import FeedForward -from diffusers.models.attention_processor import ( - Attention, - AttentionProcessor, -) -from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.normalization import ( - AdaLayerNormContinuous, - AdaLayerNormZero, - AdaLayerNormZeroSingle, -) -from diffusers.utils import ( - USE_PEFT_BACKEND, - is_torch_version, - logging, - scale_lora_layers, - unscale_lora_layers, -) -from diffusers.utils.torch_utils import maybe_allow_in_graph -from diffusers.models.embeddings import ( - CombinedTimestepGuidanceTextProjEmbeddings, - CombinedTimestepTextProjEmbeddings, - FluxPosEmbed, -) - -from diffusers.models.modeling_outputs import Transformer2DModelOutput -from diffusers import FluxTransformer2DModel as OriginalFluxTransformer2DModel - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -is_flash_attn_available = False -"""try: - from flash_attn_interface import flash_attn_func - - is_flash_attn_available = True -except: - pass""" - - -class FluxAttnProcessor2_0: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FluxAttnProcessor2_0 requires 
PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = ( - hidden_states.shape - if encoder_hidden_states is None - else encoder_hidden_states.shape - ) - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - if encoder_hidden_states is not None: - # `context` projections. - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q( - encoder_hidden_states_query_proj - ) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k( - encoder_hidden_states_key_proj - ) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from diffusers.models.embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - if attention_mask is not None: - #print ('Attention Used') - attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - attention_mask = (attention_mask > 0).bool() - # Edit 17 - match attn dtype to query d-type - attention_mask = attention_mask.to( - device=hidden_states.device, dtype=query.dtype - ) - - hidden_states = F.scaled_dot_product_attention( - query, - key, - value, - dropout_p=0.0, - is_causal=False, - attn_mask=attention_mask, - ) - hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - return hidden_states - - -def expand_flux_attention_mask( - hidden_states: 
torch.Tensor, - attn_mask: torch.Tensor, -) -> torch.Tensor: - """ - Expand a mask so that the image is included. - """ - bsz = attn_mask.shape[0] - assert bsz == hidden_states.shape[0] - residual_seq_len = hidden_states.shape[1] - mask_seq_len = attn_mask.shape[1] - - expanded_mask = torch.ones(bsz, residual_seq_len) - expanded_mask[:, :mask_seq_len] = attn_mask - - return expanded_mask - - -@maybe_allow_in_graph -class FluxSingleTransformerBlock(nn.Module): - r""" - A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. - - Reference: https://arxiv.org/abs/2403.03206 - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the - processing of `context` conditions. - """ - - def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): - super().__init__() - self.mlp_hidden_dim = int(dim * mlp_ratio) - - self.norm = AdaLayerNormZeroSingle(dim) - self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) - self.act_mlp = nn.GELU(approximate="tanh") - self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) - - processor = FluxAttnProcessor2_0() - self.attn = Attention( - query_dim=dim, - cross_attention_dim=None, - dim_head=attention_head_dim, - heads=num_attention_heads, - out_dim=dim, - bias=True, - processor=processor, - qk_norm="rms_norm", - eps=1e-6, - pre_only=True, - ) - - def forward( - self, - hidden_states: torch.FloatTensor, - temb: torch.FloatTensor, - image_rotary_emb=None, - attention_mask: Optional[torch.Tensor] = None, - ): - residual = hidden_states - norm_hidden_states, gate = self.norm(hidden_states, emb=temb) - mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) - - if attention_mask is not None: - attention_mask = expand_flux_attention_mask( - hidden_states, - attention_mask, - ) - - attn_output = self.attn( - hidden_states=norm_hidden_states, - image_rotary_emb=image_rotary_emb, - attention_mask=attention_mask, - ) - - hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) - gate = gate.unsqueeze(1) - hidden_states = gate * self.proj_out(hidden_states) - hidden_states = residual + hidden_states - - if hidden_states.dtype == torch.float16: - hidden_states = hidden_states.clip(-65504, 65504) - - return hidden_states - - -@maybe_allow_in_graph -class FluxTransformerBlock(nn.Module): - r""" - A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. - - Reference: https://arxiv.org/abs/2403.03206 - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the - processing of `context` conditions. - """ - - def __init__( - self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6 - ): - super().__init__() - - self.norm1 = AdaLayerNormZero(dim) - - self.norm1_context = AdaLayerNormZero(dim) - - if hasattr(F, "scaled_dot_product_attention"): - processor = FluxAttnProcessor2_0() - else: - raise ValueError( - "The current PyTorch version does not support the `scaled_dot_product_attention` function." 
- ) - self.attn = Attention( - query_dim=dim, - cross_attention_dim=None, - added_kv_proj_dim=dim, - dim_head=attention_head_dim, - heads=num_attention_heads, - out_dim=dim, - context_pre_only=False, - bias=True, - processor=processor, - qk_norm=qk_norm, - eps=eps, - ) - - self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") - - self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff_context = FeedForward( - dim=dim, dim_out=dim, activation_fn="gelu-approximate" - ) - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - - def forward( - self, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor, - temb: torch.FloatTensor, - image_rotary_emb=None, - attention_mask: Optional[torch.Tensor] = None, - ): - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, emb=temb - ) - - norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = ( - self.norm1_context(encoder_hidden_states, emb=temb) - ) - - if attention_mask is not None: - attention_mask = expand_flux_attention_mask( - torch.cat([encoder_hidden_states, hidden_states], dim=1), - attention_mask, - ) - - # Attention. - attention_outputs = self.attn( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_encoder_hidden_states, - image_rotary_emb=image_rotary_emb, - attention_mask=attention_mask, - ) - if len(attention_outputs) == 2: - attn_output, context_attn_output = attention_outputs - elif len(attention_outputs) == 3: - attn_output, context_attn_output, ip_attn_output = attention_outputs - - # Process attention outputs for the `hidden_states`. - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = hidden_states + attn_output - - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = ( - norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - ) - - ff_output = self.ff(norm_hidden_states) - ff_output = gate_mlp.unsqueeze(1) * ff_output - - hidden_states = hidden_states + ff_output - if len(attention_outputs) == 3: - hidden_states = hidden_states + ip_attn_output - - # Process attention outputs for the `encoder_hidden_states`. - context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output - encoder_hidden_states = encoder_hidden_states + context_attn_output - - norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) - norm_encoder_hidden_states = ( - norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) - + c_shift_mlp[:, None] - ) - - context_ff_output = self.ff_context(norm_encoder_hidden_states) - encoder_hidden_states = ( - encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output - ) - - if encoder_hidden_states.dtype == torch.float16: - encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) - - return encoder_hidden_states, hidden_states - - -class LibreFluxTransformer2DModel( - ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin -): - """ - The Transformer model introduced in Flux. - - Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ - - Parameters: - patch_size (`int`): Patch size to turn the input data into small patches. - in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. - num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. 
- num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. - attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. - num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. - joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. - guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. - """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - patch_size: int = 1, - in_channels: int = 64, - num_layers: int = 19, - num_single_layers: int = 38, - attention_head_dim: int = 128, - num_attention_heads: int = 24, - joint_attention_dim: int = 4096, - pooled_projection_dim: int = 768, - guidance_embeds: bool = False, - axes_dims_rope: Tuple[int] = (16, 56, 56), - ): - super().__init__() - self.out_channels = in_channels - self.inner_dim = ( - self.config.num_attention_heads * self.config.attention_head_dim - ) - - self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) - text_time_guidance_cls = ( - CombinedTimestepGuidanceTextProjEmbeddings ### 3 input forward (timestep, guidance, pooled_projection) - if guidance_embeds - else CombinedTimestepTextProjEmbeddings #### 2 input forward (timestep, pooled_projection) - ) - self.time_text_embed = text_time_guidance_cls( - embedding_dim=self.inner_dim, - pooled_projection_dim=self.config.pooled_projection_dim, - ) - - self.context_embedder = nn.Linear( - self.config.joint_attention_dim, self.inner_dim - ) - self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) - - self.transformer_blocks = nn.ModuleList( - [ - FluxTransformerBlock( - dim=self.inner_dim, - num_attention_heads=self.config.num_attention_heads, - attention_head_dim=self.config.attention_head_dim, - ) - for i in range(self.config.num_layers) - ] - ) - - self.single_transformer_blocks = nn.ModuleList( - [ - FluxSingleTransformerBlock( - dim=self.inner_dim, - num_attention_heads=self.config.num_attention_heads, - attention_head_dim=self.config.attention_head_dim, - ) - for i in range(self.config.num_single_layers) - ] - ) - - self.norm_out = AdaLayerNormContinuous( - self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6 - ) - self.proj_out = nn.Linear( - self.inner_dim, patch_size * patch_size * self.out_channels, bias=True - ) - - self.gradient_checkpointing = False - # added for users to disable checkpointing every nth step - self.gradient_checkpointing_interval = None - - def set_gradient_checkpointing_interval(self, value: int): - self.gradient_checkpointing_interval = value - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. 
- """ - # set recursively - processors = {} - - def fn_recursive_add_processors( - name: str, - module: torch.nn.Module, - processors: Dict[str, AttentionProcessor], - ): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] - ): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor = None, - pooled_projections: torch.Tensor = None, - timestep: torch.LongTensor = None, - img_ids: torch.Tensor = None, - txt_ids: torch.Tensor = None, - guidance: torch.Tensor = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_block_samples=None, - controlnet_single_block_samples=None, - return_dict: bool = True, - attention_mask: Optional[torch.Tensor] = None, - controlnet_blocks_repeat: bool = False, - ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: - """ - The [`FluxTransformer2DModel`] forward method. - - Args: - hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): - Input `hidden_states`. - encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): - Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected - from the embeddings of input conditions. - timestep ( `torch.LongTensor`): - Used to indicate denoising step. - block_controlnet_hidden_states: (`list` of `torch.Tensor`): - A list of tensors that if specified are added to the residuals of transformer blocks. 
- joint_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain - tuple. - - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. - """ - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if ( - joint_attention_kwargs is not None - and joint_attention_kwargs.get("scale", None) is not None - ): - logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." - ) - hidden_states = self.x_embedder(hidden_states) - - timestep = timestep.to(hidden_states.dtype) * 1000 - if guidance is not None: - guidance = guidance.to(hidden_states.dtype) * 1000 - else: - guidance = None - - #print( self.time_text_embed) - temb = ( - self.time_text_embed(timestep,pooled_projections) - # Edit 1 # Charlie NOT NEEDED - UNDONE - if guidance is None - else self.time_text_embed(timestep, guidance, pooled_projections) - ) - encoder_hidden_states = self.context_embedder(encoder_hidden_states) - - if txt_ids.ndim == 3: - txt_ids = txt_ids[0] - if img_ids.ndim == 3: - img_ids = img_ids[0] - - ids = torch.cat((txt_ids, img_ids), dim=0) - - image_rotary_emb = self.pos_embed(ids) - - # IP adapter - if ( - joint_attention_kwargs is not None - and "ip_adapter_image_embeds" in joint_attention_kwargs - ): - ip_adapter_image_embeds = joint_attention_kwargs.pop( - "ip_adapter_image_embeds" - ) - ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) - joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) - - for index_block, block in enumerate(self.transformer_blocks): - if ( - self.training - and self.gradient_checkpointing - and ( - self.gradient_checkpointing_interval is None - or index_block % self.gradient_checkpointing_interval == 0 - ) - ): - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = ( - {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - ) - encoder_hidden_states, hidden_states = ( - torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - temb, - image_rotary_emb, - attention_mask, - **ckpt_kwargs, - ) - ) - - else: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - attention_mask=attention_mask, - ) - - # controlnet residual - if controlnet_block_samples is not None: - interval_control = len(self.transformer_blocks) / len( - controlnet_block_samples - ) - interval_control = int(np.ceil(interval_control)) - # For Xlabs ControlNet. 
- if controlnet_blocks_repeat: - hidden_states = ( - hidden_states - + controlnet_block_samples[ - index_block % len(controlnet_block_samples) - ] - ) - else: - hidden_states = ( - hidden_states - + controlnet_block_samples[index_block // interval_control] - ) - - # Flux places the text tokens in front of the image tokens in the - # sequence. - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - - for index_block, block in enumerate(self.single_transformer_blocks): - if ( - self.training - and self.gradient_checkpointing - or ( - self.gradient_checkpointing_interval is not None - and index_block % self.gradient_checkpointing_interval == 0 - ) - ): - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = ( - {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - temb, - image_rotary_emb, - attention_mask, - **ckpt_kwargs, - ) - - else: - hidden_states = block( - hidden_states=hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - attention_mask=attention_mask, - ) - - # controlnet residual - if controlnet_single_block_samples is not None: - interval_control = len(self.single_transformer_blocks) / len( - controlnet_single_block_samples - ) - interval_control = int(np.ceil(interval_control)) - hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( - hidden_states[:, encoder_hidden_states.shape[1] :, ...] - + controlnet_single_block_samples[index_block // interval_control] - ) - - hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
- - hidden_states = self.norm_out(hidden_states, temb) - output = self.proj_out(hidden_states) - - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) - -#################################### -##### CONTROL NET MODEL MERGE ###### -#################################### - - -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn as nn - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.loaders import PeftAdapterMixin -from diffusers.models.attention_processor import AttentionProcessor -from diffusers.models.modeling_utils import ModelMixin -from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version, logging, scale_lora_layers, unscale_lora_layers -from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding, zero_module -from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed -from diffusers.models.modeling_outputs import Transformer2DModelOutput - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -class FluxControlNetOutput(BaseOutput): - controlnet_block_samples: Tuple[torch.Tensor] - controlnet_single_block_samples: Tuple[torch.Tensor] - - -class LibreFluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - patch_size: int = 1, - in_channels: int = 64, - num_layers: int = 19, - num_single_layers: int = 38, - attention_head_dim: int = 128, - num_attention_heads: int = 24, - joint_attention_dim: int = 4096, - pooled_projection_dim: int = 768, - guidance_embeds: bool = False, - axes_dims_rope: List[int] = [16, 56, 56], - num_mode: int = None, - conditioning_embedding_channels: int = None, - ): - super().__init__() - self.out_channels = in_channels - self.inner_dim = num_attention_heads * attention_head_dim - - self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) - - # edit 19 - #text_time_guidance_cls = ( - # CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings - #) - - text_time_guidance_cls = CombinedTimestepGuidanceTextProjEmbeddings - text_time_cls = CombinedTimestepTextProjEmbeddings - - self.time_text_embed = text_time_cls( - embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim - ) - self.time_text_guidance_embed = text_time_guidance_cls( - embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim - ) - - self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) - self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim) - - self.transformer_blocks = nn.ModuleList( - [ - FluxTransformerBlock( - dim=self.inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - ) - for i in range(num_layers) - ] - ) - - self.single_transformer_blocks = nn.ModuleList( - [ - FluxSingleTransformerBlock( - dim=self.inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - ) - for i in range(num_single_layers) - ] - ) - - # controlnet_blocks - self.controlnet_blocks = nn.ModuleList([]) - for _ in range(len(self.transformer_blocks)): - self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, 
self.inner_dim))) - - self.controlnet_single_blocks = nn.ModuleList([]) - for _ in range(len(self.single_transformer_blocks)): - self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) - - self.union = num_mode is not None - if self.union: - self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim) - - if conditioning_embedding_channels is not None: - self.input_hint_block = ControlNetConditioningEmbedding( - conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16) - ) - self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim) - else: - self.input_hint_block = None - self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim)) - - self.gradient_checkpointing = False - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self): - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
- ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - - @classmethod - def from_transformer( - cls, - transformer, - num_layers: int = 4, - num_single_layers: int = 10, - attention_head_dim: int = 128, - num_attention_heads: int = 24, - load_weights_from_transformer=True, - ): - config = dict(transformer.config) - config["num_layers"] = num_layers - config["num_single_layers"] = num_single_layers - config["attention_head_dim"] = attention_head_dim - config["num_attention_heads"] = num_attention_heads - - controlnet = cls.from_config(config) - - if load_weights_from_transformer: - controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) - controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict()) - controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict()) - controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict()) - controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False) - controlnet.single_transformer_blocks.load_state_dict( - transformer.single_transformer_blocks.state_dict(), strict=False - ) - - controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder) - - return controlnet - - # Edit 13 Adding attention masking to forward - def forward( - self, - hidden_states: torch.Tensor, - controlnet_cond: torch.Tensor, - controlnet_mode: torch.Tensor = None, - conditioning_scale: float = 1.0, - encoder_hidden_states: torch.Tensor = None, - pooled_projections: torch.Tensor = None, - timestep: torch.LongTensor = None, - img_ids: torch.Tensor = None, - txt_ids: torch.Tensor = None, - guidance: torch.Tensor = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - attention_mask: Optional[torch.Tensor] = None, # <-- 1. ADD ARGUMENT HERE - - ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: - """ - The [`FluxTransformer2DModel`] forward method. - - Args: - hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): - Input `hidden_states`. - controlnet_cond (`torch.Tensor`): - The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. - controlnet_mode (`torch.Tensor`): - The mode tensor of shape `(batch_size, 1)`. - conditioning_scale (`float`, defaults to `1.0`): - The scale factor for ControlNet outputs. - encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): - Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected - from the embeddings of input conditions. - timestep ( `torch.LongTensor`): - Used to indicate denoising step. - block_controlnet_hidden_states: (`list` of `torch.Tensor`): - A list of tensors that if specified are added to the residuals of transformer blocks. 
- joint_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain - tuple. - - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. - """ - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." - ) - hidden_states = self.x_embedder(hidden_states) - - if self.input_hint_block is not None: - controlnet_cond = self.input_hint_block(controlnet_cond) - batch_size, channels, height_pw, width_pw = controlnet_cond.shape - height = height_pw // self.config.patch_size - width = width_pw // self.config.patch_size - controlnet_cond = controlnet_cond.reshape( - batch_size, channels, height, self.config.patch_size, width, self.config.patch_size - ) - controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5) - controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1) - # add - hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond) - - timestep = timestep.to(hidden_states.dtype) * 1000 - if guidance is not None: - guidance = guidance.to(hidden_states.dtype) * 1000 - else: - guidance = None - - #print ('Guidance:', guidance) - temb = ( - self.time_text_embed(timestep, pooled_projections) - if guidance is None - # edit 19 - else self.time_text_guidance_embed(timestep, guidance, pooled_projections) - ) - encoder_hidden_states = self.context_embedder(encoder_hidden_states) - - if self.union: - # union mode - if controlnet_mode is None: - raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union") - # union mode emb - controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode) - encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1) - txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0) - - if txt_ids.ndim == 3: - logger.warning( - "Passing `txt_ids` 3d torch.Tensor is deprecated." - "Please remove the batch dimension and pass it as a 2d torch Tensor" - ) - txt_ids = txt_ids[0] - if img_ids.ndim == 3: - logger.warning( - "Passing `img_ids` 3d torch.Tensor is deprecated." 
- "Please remove the batch dimension and pass it as a 2d torch Tensor" - ) - img_ids = img_ids[0] - - ids = torch.cat((txt_ids, img_ids), dim=0) - image_rotary_emb = self.pos_embed(ids) - - block_samples = () - for index_block, block in enumerate(self.transformer_blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - temb, - image_rotary_emb, - attention_mask, # Edit 13 - **ckpt_kwargs, - ) - - else: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - attention_mask=attention_mask, # Edit 13 - - ) - block_samples = block_samples + (hidden_states,) - - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - - single_block_samples = () - for index_block, block in enumerate(self.single_transformer_blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - temb, - image_rotary_emb, - attention_mask, # <-- 2. PASS MASK TO GRADIENT CHECKPOINTING # Edit 13 - **ckpt_kwargs, - ) - - else: - hidden_states = block( - hidden_states=hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - attention_mask=attention_mask, # <-- 2. 
PASS MASK TO BLOCK Edit 13 - - ) - single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) - - # controlnet block - controlnet_block_samples = () - for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): - block_sample = controlnet_block(block_sample) - controlnet_block_samples = controlnet_block_samples + (block_sample,) - - controlnet_single_block_samples = () - for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks): - single_block_sample = controlnet_block(single_block_sample) - controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,) - - # scaling - controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples] - controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples] - - controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples - controlnet_single_block_samples = ( - None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples - ) - - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - - if not return_dict: - return (controlnet_block_samples, controlnet_single_block_samples) - - return FluxControlNetOutput( - controlnet_block_samples=controlnet_block_samples, - controlnet_single_block_samples=controlnet_single_block_samples, - ) - - -#################################### -##### ACTUAL PIPELINE STUFF ######## -#################################### - - -from diffusers.schedulers import FlowMatchEulerDiscreteScheduler -from diffusers.utils import ( - USE_PEFT_BACKEND, - is_torch_xla_available, - logging, - replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, -) -from diffusers.utils.torch_utils import randn_tensor -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput - - -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -else: - XLA_AVAILABLE = False - -# TODO(Chris): why won't this emit messages at the INFO level??? -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from diffusers.utils import load_image - >>> from diffusers import FluxControlNetPipeline - >>> from diffusers import FluxControlNetModel - - >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny" - >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16) - >>> pipe = FluxControlNetPipeline.from_pretrained( - ... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16 - ... ) - >>> pipe.to("cuda") - >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") - >>> prompt = "A girl in city, 25 years old, cool, futuristic" - >>> image = pipe( - ... prompt, - ... control_image=control_image, - ... controlnet_conditioning_scale=0.6, - ... num_inference_steps=28, - ... guidance_scale=3.5, - ... 
).images[0] - >>> image.save("flux.png") - ``` -""" - -def _maybe_to(x: torch.Tensor, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): - if device is None and dtype is None: - return x - need_dev = device is not None and str(getattr(x, "device", None)) != str(device) - need_dt = dtype is not None and getattr(x, "dtype", None) != dtype - return x.to(device=device if need_dev else x.device, dtype=dtype if need_dt else x.dtype) if (need_dev or need_dt) else x - - -# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift -def calculate_shift( - image_seq_len, - base_seq_len: int = 256, - max_seq_len: int = 4096, - base_shift: float = 0.5, - max_shift: float = 1.16, -): - m = (max_shift - base_shift) / (max_seq_len - base_seq_len) - b = base_shift - m * base_seq_len - mu = image_seq_len * m + b - return mu - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - """ - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." 
- ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -class LibreFluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): - r""" - The Flux pipeline for text-to-image generation. - - Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ - - Args: - transformer ([`FluxTransformer2DModel`]): - Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. - scheduler ([`FlowMatchEulerDiscreteScheduler`]): - A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - text_encoder_2 ([`T5EncoderModel`]): - [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically - the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). - tokenizer_2 (`T5TokenizerFast`): - Second Tokenizer of class - [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). - """ - - model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" - _optional_components = [] - _callback_tensor_inputs = ["latents", "prompt_embeds"] - - def __init__( - self, - scheduler: FlowMatchEulerDiscreteScheduler, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - text_encoder_2: T5EncoderModel, - tokenizer_2: T5TokenizerFast, - transformer: LibreFluxTransformer2DModel, - controlnet: Union[ - LibreFluxControlNetModel, List[LibreFluxControlNetModel], Tuple[LibreFluxControlNetModel], - ], - ): - super().__init__() - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - text_encoder_2=text_encoder_2, - tokenizer=tokenizer, - tokenizer_2=tokenizer_2, - transformer=transformer, - scheduler=scheduler, - controlnet=controlnet, - ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 - ) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.tokenizer_max_length = ( - self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 - ) - self.default_sample_size = 64 - - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - text_inputs = self.tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - 
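-        # The block below keeps the fixed-length T5 input ids, warns if the prompt was truncated at
-        # `max_sequence_length`, encodes them with `text_encoder_2`, duplicates the embeddings per requested
-        # image, and returns the tokenizer's attention mask alongside them so padding can later be masked
-        # out of attention.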
text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder_2(text_input_ids.to(self.text_encoder_2.device), output_hidden_states=False)[0] - #prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] - - dtype = self.text_encoder_2.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings for each generation per prompt - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # ADD THIS: Get the attention mask and repeat it for each image - prompt_attention_mask = text_inputs.attention_mask.to(device=device, dtype=dtype) - prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) - - # ADD THIS: Return the attention mask - return prompt_embeds, prompt_attention_mask - - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = self.text_encoder(text_input_ids.to(self.text_encoder.device), output_hidden_states=False) - #prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) - - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - - def encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Union[str, List[str]], - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 512, - lora_scale: Optional[float] = None, - ): - device = device or self._execution_device - - if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): - 
self._lora_scale = lora_scale - if self.text_encoder is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - ) - - # ADD THIS: Initialize mask and capture it from the T5 embedder - prompt_attention_mask = None - prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - if self.text_encoder is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - unscale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - unscale_lora_layers(self.text_encoder_2, lora_scale) - - # FIX: Get batch_size and create text_ids with the correct shape - batch_size = prompt_embeds.shape[0] - dtype = self.transformer.dtype - text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - - return prompt_embeds, pooled_prompt_embeds, text_ids, prompt_attention_mask - - def check_inputs( - self, - prompt, - prompt_2, - height, - width, - prompt_embeds=None, - pooled_prompt_embeds=None, - callback_on_step_end_tensor_inputs=None, - max_sequence_length=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt_2 is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") - - if prompt_embeds is not None and pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 
- ) - - if max_sequence_length is not None and max_sequence_length > 512: - raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids - # FIX: Correctly creates batched image IDs - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height // 2, width // 2, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] - - latent_image_ids = latent_image_ids.unsqueeze(0).repeat(batch_size, 1, 1, 1) - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape[1:] - - latent_image_ids = latent_image_ids.reshape( - batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - - return latents - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents - def _unpack_latents(latents, height, width, vae_scale_factor): - batch_size, num_patches, channels = latents.shape - - height = height // vae_scale_factor - width = width // vae_scale_factor - - latents = latents.view(batch_size, height, width, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) - - return latents - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents - def prepare_latents( - self, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - ): - height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) - - shape = (batch_size, num_channels_latents, height, width) - - if latents is not None: - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) - return latents.to(device=device, dtype=dtype), latent_image_ids - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) - - return latents, latent_image_ids - - # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image - def prepare_image( - self, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - do_classifier_free_guidance=False, - guess_mode=False, - ): - if isinstance(image, torch.Tensor): - pass - else: - image = self.image_processor.preprocess(image, height=height, width=width) - - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - if do_classifier_free_guidance and not guess_mode: - image = torch.cat([image] * 2) - - return image - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 28, - timesteps: List[int] = None, - guidance_scale: float = 7.0, - control_image: PipelineImageInput = None, - control_mode: Optional[Union[int, List[int]]] = None, - control_image_undo_centering: bool = False, - controlnet_conditioning_scale: Union[float, List[float]] = 1.0, - num_images_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 512, - negative_prompt: Optional[Union[str, List[str]]] = "", - negative_prompt_2: Optional[Union[str, List[str]]] = "", - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - will be used instead - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. This is set to 1024 by default for the best results. 
-            width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
-                The width in pixels of the generated image. This is set to 1024 by default for the best results.
-            num_inference_steps (`int`, *optional*, defaults to 28):
-                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
-                expense of slower inference.
-            timesteps (`List[int]`, *optional*):
-                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
-                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
-                passed will be used. Must be in descending order.
-            guidance_scale (`float`, *optional*, defaults to 7.0):
-                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-                `guidance_scale` is defined as `w` of equation 2. of [Imagen
-                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
-                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
-                text `prompt`, usually at the expense of lower image quality.
-            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
-                `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
-                The ControlNet input condition to provide guidance to the `transformer` for generation. If the type is
-                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
-                as an image. The dimensions of the output image default to `image`'s dimensions. If height and/or
-                width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
-                images must be passed as a list such that each element of the list can be correctly batched for input
-                to a single ControlNet.
-            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
-                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
-                to the residual in the original `transformer`. If multiple ControlNets are specified in `init`, you
-                can set the corresponding scale as a list.
-            control_mode (`int` or `List[int]`, *optional*, defaults to None):
-                The control mode when applying ControlNet-Union.
-            num_images_per_prompt (`int`, *optional*, defaults to 1):
-                The number of images to generate per prompt.
-            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
-                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
-                to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
-                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
-                generation. Can be used to tweak the same generation with different prompts. If not provided, a
-                latents tensor will be generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
-                If not provided, pooled text embeddings will be generated from `prompt` input argument.
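-            control_image_undo_centering (`bool`, *optional*, defaults to `False`):
-                If `True`, rescales the normalized control image from `[-1, 1]` back to `[0, 1]` before it is
-                encoded by the VAE; requires the image processor to have `do_normalize=True`.
-            negative_prompt (`str` or `List[str]`, *optional*, defaults to `""`):
-                The prompt or prompts not to guide the image generation. Only used when `guidance_scale > 1.0`,
-                in which case an unconditional batch is encoded and true classifier-free guidance is applied.
-            negative_prompt_2 (`str` or `List[str]`, *optional*, defaults to `""`):
-                The negative prompt to be sent to `tokenizer_2` and `text_encoder_2`. If not defined,
-                `negative_prompt` is used instead.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated negative text embeddings. If not provided, they are generated from
-                `negative_prompt` when classifier-free guidance is enabled.
-            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated negative pooled text embeddings. If not provided, they are generated from
-                `negative_prompt` when classifier-free guidance is enabled.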
- output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. - joint_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List`, *optional*): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeline class. - max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. - - Examples: - - Returns: - [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` - is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated - images. - """ - - height = height or self.default_sample_size * self.vae_scale_factor - width = width or self.default_sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - prompt_2, - height, - width, - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, - max_sequence_length=max_sequence_length, - ) - - self._guidance_scale = guidance_scale - self._joint_attention_kwargs = joint_attention_kwargs - self._interrupt = False - - # 2. 
Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - dtype = self.transformer.dtype - - lora_scale = ( - self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None - ) - # 💡 ADD THIS: Capture the attention_mask from encode_prompt - ( - prompt_embeds, - pooled_prompt_embeds, - text_ids, - attention_mask, - ) = self.encode_prompt( - prompt=prompt, - prompt_2=prompt_2, - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - device=device, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - lora_scale=lora_scale, - ) - - # ✨ FIX: Encode negative prompts for CFG - do_classifier_free_guidance = guidance_scale > 1.0 - if do_classifier_free_guidance: - if negative_prompt_embeds is None or negative_pooled_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt - (negative_prompt_embeds, negative_pooled_prompt_embeds, negative_text_ids, negative_attention_mask) = self.encode_prompt( - prompt=negative_prompt, prompt_2=negative_prompt_2, device=device, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, lora_scale=lora_scale, - ) - - - # 3. Prepare control image - num_channels_latents = self.transformer.config.in_channels // 4 - - if type(self.controlnet) == FullyShardedDataParallel: - inner_module = self.controlnet._fsdp_wrapped_module - else: - inner_module = self.controlnet - - if isinstance(inner_module, LibreFluxControlNetModel): - control_image = self.prepare_image( - image=control_image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=dtype, - ) - - if control_image_undo_centering: - if not self.image_processor.do_normalize: - raise ValueError( - "`control_image_undo_centering` only makes sense if `do_normalize==True` in the image processor" - ) - control_image = control_image*0.5 + 0.5 - - height, width = control_image.shape[-2:] - - #logger.warning( - # f"pipeline_flux_controlnet, control_image: {control_image.min()} {control_image.max()}" - #) - - # vae encode - control_image = _maybe_to(control_image, device=self.vae.device) - control_image = self.vae.encode(control_image).latent_dist.sample() - control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor - control_image = _maybe_to(control_image, device=device) - # pack - height_control_image, width_control_image = control_image.shape[2:] - control_image = self._pack_latents( - control_image, - batch_size * num_images_per_prompt, - num_channels_latents, - height_control_image, - width_control_image, - ) - - # set control mode - if control_mode is not None: - control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) - control_mode = control_mode.reshape([-1, 1]) - - - # set control mode - control_mode_ = [] - if isinstance(control_mode, list): - for cmode in control_mode: - if cmode is None: - control_mode_.append(-1) - else: - control_mode_.append(cmode) - control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) - control_mode = control_mode.reshape([-1, 1]) - - # 4. 
Prepare latent variables - num_channels_latents = self.transformer.config.in_channels // 4 - latents, latent_image_ids = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 5. Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = latents.shape[1] - mu = calculate_shift( - image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, - ) - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, - num_inference_steps, - device, - timesteps, - sigmas, - mu=mu, - ) - - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - self._num_timesteps = len(timesteps) - - # 6. Denoising loop - target_device = self.transformer.device - self.controlnet.to(target_device) - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - - # FIX: BATCH INPUTS FOR CFG - if do_classifier_free_guidance: - latent_model_input = torch.cat([latents] * 2) - current_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - current_pooled_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds]) - current_attention_mask = torch.cat([negative_attention_mask, attention_mask]) - current_text_ids = torch.cat([negative_text_ids, text_ids]) - current_img_ids = torch.cat([latent_image_ids] * 2) - current_control_image = torch.cat([control_image] * 2) if isinstance(control_image, torch.Tensor) else [torch.cat([c_img] * 2) for c_img in control_image] - else: - latent_model_input = latents - current_prompt_embeds = prompt_embeds - current_pooled_embeds = pooled_prompt_embeds - current_attention_mask = attention_mask - current_text_ids = text_ids - current_img_ids = latent_image_ids - current_control_image = control_image - - # FIX: Integrate with device handling - target_device = self.transformer.device - - # Move all inputs to the target device - latent_model_input = _maybe_to(latent_model_input, device=target_device) - current_prompt_embeds = _maybe_to(current_prompt_embeds, device=target_device) - current_pooled_embeds = _maybe_to(current_pooled_embeds, device=target_device) - current_attention_mask = _maybe_to(current_attention_mask, device=target_device) - current_text_ids = _maybe_to(current_text_ids, device=target_device) - current_img_ids = _maybe_to(current_img_ids, device=target_device) - if isinstance(current_control_image, torch.Tensor): - current_control_image = _maybe_to(current_control_image, device=target_device) - else: - current_control_image = [ _maybe_to(c, device=target_device) for c in current_control_image ] - control_mode = _maybe_to(control_mode, device=target_device) if control_mode is not None else None - - t_model = t.expand(latent_model_input.shape[0]).to(target_device) - - - # Model calls - controlnet_block_samples, controlnet_single_block_samples = self.controlnet( - hidden_states=latent_model_input, - controlnet_cond=current_control_image, - controlnet_mode=control_mode, - conditioning_scale=controlnet_conditioning_scale, - timestep=(t_model / 1000), - guidance=None, - pooled_projections=current_pooled_embeds, - encoder_hidden_states=current_prompt_embeds, - attention_mask=current_attention_mask, - txt_ids=current_text_ids, - img_ids=current_img_ids, - 
joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False - ) - - controlnet_block_samples = [elem.to(dtype=latents.dtype, device=target_device) for elem in controlnet_block_samples] - controlnet_single_block_samples = [elem.to(dtype=latents.dtype, device=target_device) for elem in controlnet_single_block_samples] - - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=(t_model / 1000), - guidance=None, - pooled_projections=current_pooled_embeds, - encoder_hidden_states=current_prompt_embeds, - attention_mask=current_attention_mask, - controlnet_block_samples=controlnet_block_samples, - controlnet_single_block_samples=controlnet_single_block_samples, - txt_ids=current_text_ids, - img_ids=current_img_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False - )[0] - - # FIX: Apply CFG formula - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) - - ## Probably not needed - #noise_pred = noise_pred.to(latents.device) - - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - if XLA_AVAILABLE: - xm.mark_step() - - if output_type == "latent": - image = latents - - else: - latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) - latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor - - latents = _maybe_to(latents, device=self.vae.device) - image = self.vae.decode(latents, return_dict=False)[0] - image = self.image_processor.postprocess(image, output_type=output_type) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return (image,) - - return FluxPipelineOutput(images=image) \ No newline at end of file
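
A minimal usage sketch for the pipeline above, assuming `pipe` is an already-constructed `LibreFluxControlNetPipeline` that has been moved to the target device (component loading is not shown); the prompt, negative prompt, and output filename are placeholders, and the control image URL is the one referenced in the pipeline's example docstring. Passing `guidance_scale > 1.0` together with a `negative_prompt` exercises the true-CFG branch added in `__call__`.

import torch
from diffusers.utils import load_image

# Placeholder control image (same canny example referenced in the docstring above).
control_image = load_image(
    "https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg"
)

image = pipe(  # `pipe`: a LibreFluxControlNetPipeline assembled elsewhere (assumption)
    prompt="A girl in city, 25 years old, cool, futuristic",
    negative_prompt="blurry, low quality, deformed",
    control_image=control_image,
    controlnet_conditioning_scale=0.6,
    num_inference_steps=28,
    guidance_scale=3.5,  # > 1.0, so the negative prompt is encoded and CFG is applied
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("libreflux_controlnet.png")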