import html
from typing import List, Optional, Union

import regex as re
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL
from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import FluxModularPipeline


if is_ftfy_available():
    import ftfy


logger = logging.get_logger(__name__)


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


def prompt_clean(text):
    text = whitespace_clean(basic_clean(text))
    return text
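

# Pull latents out of a VAE encoder output: sample from the latent distribution,
# take its mode ("argmax"), or fall back to precomputed `.latents`.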
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


class FluxVaeEncoderStep(ModularPipelineBlocks):
    model_name = "flux"

    @property
    def description(self) -> str:
        return "VAE encoder step that encodes the input image into a latent representation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("image", required=True),
            InputParam("height"),
            InputParam("width"),
            InputParam("generator"),
            InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
            InputParam(
                "preprocess_kwargs",
                type_hint=Optional[dict],
                description="A kwargs dictionary that, if specified, is passed along to the `VaeImageProcessor` defined under `self.image_processor`",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image_latents",
                type_hint=torch.Tensor,
                description="The latents representing the reference image for image-to-image/inpainting generation",
            )
        ]

    @staticmethod
    def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
        if isinstance(generator, list):
            image_latents = [
                retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
            ]
            image_latents = torch.cat(image_latents, dim=0)
        else:
            image_latents = retrieve_latents(vae.encode(image), generator=generator)
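
        # normalize latents with the VAE's configured shift and scaling factors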
        image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor

        return image_latents

    @torch.no_grad()
    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
        block_state.device = components._execution_device
        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype

        block_state.image = components.image_processor.preprocess(
            block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
        )
        block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)

        block_state.batch_size = block_state.image.shape[0]
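
        # a list of generators must match the image batch one-to-one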
        if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
                f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
            )

        block_state.image_latents = self._encode_vae_image(
            components.vae, image=block_state.image, generator=block_state.generator
        )

        self.set_block_state(state, block_state)

        return components, state


class FluxTextEncoderStep(ModularPipelineBlocks):
    model_name = "flux"

    @property
    def description(self) -> str:
        return "Text encoder step that generates text embeddings to guide the image generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", CLIPTextModel),
            ComponentSpec("tokenizer", CLIPTokenizer),
            ComponentSpec("text_encoder_2", T5EncoderModel),
            ComponentSpec("tokenizer_2", T5TokenizerFast),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("prompt_2"),
            InputParam("joint_attention_kwargs"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="text embeddings used to guide the image generation",
            ),
            OutputParam(
                "pooled_prompt_embeds",
                type_hint=torch.Tensor,
                description="pooled text embeddings used to guide the image generation",
            ),
            OutputParam(
                "text_ids",
                type_hint=torch.Tensor,
                description="ids from the text sequence for RoPE",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        for prompt in [block_state.prompt, block_state.prompt_2]:
            if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
                raise ValueError(f"`prompt` or `prompt_2` has to be of type `str` or `list` but is {type(prompt)}")

    @staticmethod
    def _get_t5_prompt_embeds(
        components,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int,
        max_sequence_length: int,
        device: torch.device,
    ):
        dtype = components.text_encoder_2.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        if isinstance(components, TextualInversionLoaderMixin):
            prompt = components.maybe_convert_prompt(prompt, components.tokenizer_2)

        text_inputs = components.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        untruncated_ids = components.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = components.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to"
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = components.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        _, seq_len, _ = prompt_embeds.shape
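
        # duplicate text embeddings for each generation per prompt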
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds

    @staticmethod
    def _get_clip_prompt_embeds(
        components,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int,
        device: torch.device,
    ):
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        if isinstance(components, TextualInversionLoaderMixin):
            prompt = components.maybe_convert_prompt(prompt, components.tokenizer)

        text_inputs = components.tokenizer(
            prompt,
            padding="max_length",
            max_length=components.tokenizer.model_max_length,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        tokenizer_max_length = components.tokenizer.model_max_length
        untruncated_ids = components.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = components.tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {tokenizer_max_length} tokens: {removed_text}"
            )
        prompt_embeds = components.text_encoder(text_input_ids.to(device), output_hidden_states=False)
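
        # use the pooled output of the CLIP text encoder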
        prompt_embeds = prompt_embeds.pooler_output
        prompt_embeds = prompt_embeds.to(dtype=components.text_encoder.dtype, device=device)
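
        # duplicate text embeddings for each generation per prompt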
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        return prompt_embeds

    @staticmethod
    def encode_prompt(
        components,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to be encoded.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in all text-encoders.
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            max_sequence_length (`int`, defaults to 512):
                Maximum sequence length to use with `tokenizer_2` and `text_encoder_2`.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        """
        device = device or components._execution_device
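
        # set lora scale so that monkey patched LoRA function of the text encoder can correctly access it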
        if lora_scale is not None and isinstance(components, FluxLoraLoaderMixin):
            components._lora_scale = lora_scale
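
            # dynamically adjust the LoRA scale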
            if components.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(components.text_encoder, lora_scale)
            if components.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(components.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
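
            # we only use the pooled output of the CLIP text encoder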
            pooled_prompt_embeds = FluxTextEncoderStep._get_clip_prompt_embeds(
                components,
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
            )
            prompt_embeds = FluxTextEncoderStep._get_t5_prompt_embeds(
                components,
                prompt=prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        if components.text_encoder is not None:
            if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
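                # retrieve the original scale by scaling back the LoRA layers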
                unscale_lora_layers(components.text_encoder, lora_scale)

        if components.text_encoder_2 is not None:
            if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
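                # retrieve the original scale by scaling back the LoRA layers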
                unscale_lora_layers(components.text_encoder_2, lora_scale)

        dtype = components.text_encoder.dtype if components.text_encoder is not None else torch.bfloat16
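        # Flux assigns all-zero position ids to the text tokens for RoPE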
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        return prompt_embeds, pooled_prompt_embeds, text_ids

    @torch.no_grad()
    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        block_state.device = components._execution_device

        block_state.text_encoder_lora_scale = (
            block_state.joint_attention_kwargs.get("scale", None)
            if block_state.joint_attention_kwargs is not None
            else None
        )
        (block_state.prompt_embeds, block_state.pooled_prompt_embeds, block_state.text_ids) = self.encode_prompt(
            components,
            prompt=block_state.prompt,
            prompt_2=None,
            prompt_embeds=None,
            pooled_prompt_embeds=None,
            device=block_state.device,
            num_images_per_prompt=1,
            lora_scale=block_state.text_encoder_lora_scale,
        )

        self.set_block_state(state, block_state)
        return components, state
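

# A minimal usage sketch (hypothetical; `init_pipeline` and the state access below are
# assumptions about the surrounding modular-pipeline API, not guarantees of this module):
#
#     blocks = FluxTextEncoderStep()
#     pipe = blocks.init_pipeline("black-forest-labs/FLUX.1-dev")  # assumed helper
#     state = pipe(prompt="a photo of a cat")
#     # the resulting state carries `prompt_embeds`, `pooled_prompt_embeds`, and `text_ids`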