Spaces:
Running on Zero
Running on Zero
| # Copyright 2025 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from typing import List | |
| import torch | |
| from ...pipelines import FluxPipeline | |
| from ...utils import logging | |
| from ..modular_pipeline import ModularPipelineBlocks, PipelineState | |
| from ..modular_pipeline_utils import InputParam, OutputParam | |
| # TODO: consider making these common utilities for modular if they are not pipeline-specific. | |
| from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_to_batch_size | |
| from .modular_pipeline import FluxModularPipeline | |
| logger = logging.get_logger(__name__) | |
class FluxTextInputStep(ModularPipelineBlocks):
    """Standardize text embeddings before the rest of the Flux pipeline uses them.

    The step reads `prompt_embeds` (and `pooled_prompt_embeds` when supplied),
    records the prompt batch size and dtype, and tiles every embedding
    `num_images_per_prompt` times along the batch dimension.
    """

    model_name = "flux"

    @property
    def description(self) -> str:
        """Human-readable summary of what this block does."""
        return (
            "Text input processing step that standardizes text embeddings for the pipeline.\n"
            "This step:\n"
            "  1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
            "  2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        """Inputs read from the pipeline state by this block."""
        return [
            InputParam("num_images_per_prompt", default=1),
            InputParam(
                "prompt_embeds",
                required=True,
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="Pre-generated text embeddings. Can be generated from text_encoder step.",
            ),
            InputParam(
                "pooled_prompt_embeds",
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.",
            ),
            # TODO: support negative embeddings?
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Values this block writes back to the pipeline state."""
        return [
            OutputParam(
                "batch_size",
                type_hint=int,
                description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
            ),
            OutputParam(
                "dtype",
                type_hint=torch.dtype,
                description="Data type of model tensor inputs (determined by `prompt_embeds`)",
            ),
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="text embeddings used to guide the image generation",
            ),
            OutputParam(
                "pooled_prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="pooled text embeddings used to guide the image generation",
            ),
            # TODO: support negative embeddings?
        ]

    def check_inputs(self, components, block_state):
        """Validate that both embeddings, when given, share the same batch size.

        Raises:
            ValueError: if the leading dimensions of `prompt_embeds` and
                `pooled_prompt_embeds` differ.
        """
        if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is not None:
            if block_state.prompt_embeds.shape[0] != block_state.pooled_prompt_embeds.shape[0]:
                raise ValueError(
                    "`prompt_embeds` and `pooled_prompt_embeds` must have the same batch size when passed directly, but"
                    f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `pooled_prompt_embeds`"
                    f" {block_state.pooled_prompt_embeds.shape}."
                )

    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
        """Record batch metadata and expand the embeddings.

        Returns:
            The `(components, state)` pair, with `batch_size`, `dtype` and the
            expanded embeddings written to `state`.
        """
        # TODO: consider adding negative embeddings?
        block_state = self.get_block_state(state)
        self.check_inputs(components, block_state)

        block_state.batch_size = block_state.prompt_embeds.shape[0]
        block_state.dtype = block_state.prompt_embeds.dtype
        expanded_batch = block_state.batch_size * block_state.num_images_per_prompt

        # Tile along the sequence axis, then fold the copies into the batch axis so
        # that the copies belonging to one prompt stay adjacent.
        _, seq_len, _ = block_state.prompt_embeds.shape
        block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
        block_state.prompt_embeds = block_state.prompt_embeds.view(expanded_batch, seq_len, -1)

        # `pooled_prompt_embeds` is declared optional in `inputs`, so only expand it
        # when it was actually provided instead of failing on `None.repeat`.
        if block_state.pooled_prompt_embeds is not None:
            pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt)
            block_state.pooled_prompt_embeds = pooled_prompt_embeds.view(expanded_batch, -1)

        self.set_block_state(state, block_state)
        return components, state
| # Adapted from `QwenImageAdditionalInputsStep` | |
class FluxInputsDynamicStep(ModularPipelineBlocks):
    """Configurable input-processing block for Flux image latents and extra batch inputs.

    Args:
        image_latent_inputs: Name (or list of names) of state entries holding image
            latents. Each one is used to fill in missing `height`/`width`, packed
            with `FluxPipeline._pack_latents`, and expanded along the batch dimension.
        additional_batch_inputs: Name (or list of names) of state entries that only
            need batch expansion.
    """

    model_name = "flux"

    def __init__(
        self,
        image_latent_inputs: List[str] = ["image_latents"],
        additional_batch_inputs: List[str] = [],
    ):
        if not isinstance(image_latent_inputs, list):
            image_latent_inputs = [image_latent_inputs]
        if not isinstance(additional_batch_inputs, list):
            additional_batch_inputs = [additional_batch_inputs]

        # Store copies: the defaults are module-level list objects shared by every
        # call, so keeping a reference would let one instance's mutation leak into
        # all the others (and into the caller's own list).
        self._image_latent_inputs = list(image_latent_inputs)
        self._additional_batch_inputs = list(additional_batch_inputs)
        super().__init__()

    @property
    def description(self) -> str:
        """Human-readable summary, including the configured input names."""
        # Functionality section
        summary_section = (
            "Input processing step that:\n"
            "  1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n"
            "  2. For additional batch inputs: Expands batch dimensions to match final batch size"
        )

        # Inputs info
        inputs_info = ""
        if self._image_latent_inputs or self._additional_batch_inputs:
            inputs_info = "\n\nConfigured inputs:"
            if self._image_latent_inputs:
                inputs_info += f"\n  - Image latent inputs: {self._image_latent_inputs}"
            if self._additional_batch_inputs:
                inputs_info += f"\n  - Additional batch inputs: {self._additional_batch_inputs}"

        # Placement guidance
        placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."

        return summary_section + inputs_info + placement_section

    @property
    def inputs(self) -> List[InputParam]:
        """Fixed inputs plus one `InputParam` per configured input name."""
        inputs = [
            InputParam(name="num_images_per_prompt", default=1),
            InputParam(name="batch_size", required=True),
            InputParam(name="height"),
            InputParam(name="width"),
        ]

        # Add image latent inputs
        inputs.extend(InputParam(name=input_name) for input_name in self._image_latent_inputs)

        # Add additional batch inputs
        inputs.extend(InputParam(name=input_name) for input_name in self._additional_batch_inputs)

        return inputs

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Values this block writes back to the pipeline state."""
        return [
            OutputParam(name="image_height", type_hint=int, description="The height of the image latents"),
            OutputParam(name="image_width", type_hint=int, description="The width of the image latents"),
        ]

    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
        """Process every configured input that is present in the state.

        Inputs whose value is `None` are skipped.

        Returns:
            The `(components, state)` pair with the processed tensors written back.
        """
        block_state = self.get_block_state(state)

        # Process image latent inputs (height/width calculation, patchify, and batch expansion)
        for image_latent_input_name in self._image_latent_inputs:
            image_latent_tensor = getattr(block_state, image_latent_input_name)
            if image_latent_tensor is None:
                continue

            # 1. Calculate height/width from latents
            height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
            # Explicit user-provided height/width win over the values derived from the latents.
            block_state.height = block_state.height or height
            block_state.width = block_state.width or width

            # Only the first image latent input that is present sets image_height/image_width.
            if not hasattr(block_state, "image_height"):
                block_state.image_height = height
            if not hasattr(block_state, "image_width"):
                block_state.image_width = width

            # 2. Patchify the image latent tensor
            # TODO: Implement patchifier for Flux.
            # Pack with the tensor's own leading dimension: `block_state.batch_size`
            # counts prompts and need not equal the number of image latents supplied
            # (e.g. one image shared by several prompts). Expansion to the final batch
            # size is delegated to `repeat_tensor_to_batch_size` below.
            latent_batch, latent_channels, latent_height, latent_width = image_latent_tensor.shape
            image_latent_tensor = FluxPipeline._pack_latents(
                image_latent_tensor, latent_batch, latent_channels, latent_height, latent_width
            )

            # 3. Expand batch size
            image_latent_tensor = repeat_tensor_to_batch_size(
                input_name=image_latent_input_name,
                input_tensor=image_latent_tensor,
                num_images_per_prompt=block_state.num_images_per_prompt,
                batch_size=block_state.batch_size,
            )

            setattr(block_state, image_latent_input_name, image_latent_tensor)

        # Process additional batch inputs (only batch expansion)
        for input_name in self._additional_batch_inputs:
            input_tensor = getattr(block_state, input_name)
            if input_tensor is None:
                continue

            # Only expand batch size
            input_tensor = repeat_tensor_to_batch_size(
                input_name=input_name,
                input_tensor=input_tensor,
                num_images_per_prompt=block_state.num_images_per_prompt,
                batch_size=block_state.batch_size,
            )

            setattr(block_state, input_name, input_tensor)

        self.set_block_state(state, block_state)
        return components, state
class FluxKontextInputsDynamicStep(FluxInputsDynamicStep):
    """Flux Kontext variant of the dynamic inputs step.

    Identical to `FluxInputsDynamicStep` except that `height` and `width` in the
    state are never overwritten with values derived from the image latents; only
    `image_height` / `image_width` are recorded.
    """

    model_name = "flux-kontext"

    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
        """Process every configured input that is present in the state.

        Inputs whose value is `None` are skipped.

        Returns:
            The `(components, state)` pair with the processed tensors written back.
        """
        block_state = self.get_block_state(state)

        # Process image latent inputs (height/width calculation, patchify, and batch expansion)
        for image_latent_input_name in self._image_latent_inputs:
            image_latent_tensor = getattr(block_state, image_latent_input_name)
            if image_latent_tensor is None:
                continue

            # 1. Calculate height/width from latents
            # Unlike the `FluxInputsDynamicStep`, we don't overwrite the `block.height` and `block.width`
            height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
            # Only the first image latent input that is present sets image_height/image_width.
            if not hasattr(block_state, "image_height"):
                block_state.image_height = height
            if not hasattr(block_state, "image_width"):
                block_state.image_width = width

            # 2. Patchify the image latent tensor
            # TODO: Implement patchifier for Flux.
            # Pack with the tensor's own leading dimension: `block_state.batch_size`
            # counts prompts and need not equal the number of image latents supplied
            # (e.g. one image shared by several prompts). Expansion to the final batch
            # size is delegated to `repeat_tensor_to_batch_size` below.
            latent_batch, latent_channels, latent_height, latent_width = image_latent_tensor.shape
            image_latent_tensor = FluxPipeline._pack_latents(
                image_latent_tensor, latent_batch, latent_channels, latent_height, latent_width
            )

            # 3. Expand batch size
            image_latent_tensor = repeat_tensor_to_batch_size(
                input_name=image_latent_input_name,
                input_tensor=image_latent_tensor,
                num_images_per_prompt=block_state.num_images_per_prompt,
                batch_size=block_state.batch_size,
            )

            setattr(block_state, image_latent_input_name, image_latent_tensor)

        # Process additional batch inputs (only batch expansion)
        for input_name in self._additional_batch_inputs:
            input_tensor = getattr(block_state, input_name)
            if input_tensor is None:
                continue

            # Only expand batch size
            input_tensor = repeat_tensor_to_batch_size(
                input_name=input_name,
                input_tensor=input_tensor,
                num_images_per_prompt=block_state.num_images_per_prompt,
                batch_size=block_state.batch_size,
            )

            setattr(block_state, input_name, input_tensor)

        self.set_block_state(state, block_state)
        return components, state
class FluxKontextSetResolutionStep(ModularPipelineBlocks):
    """Pick the generation `height`/`width` for Flux Kontext.

    The requested (or default) resolution is rescaled to roughly `max_area`
    pixels at the same aspect ratio, then each side is floored to a multiple of
    `vae_scale_factor * 2`.
    """

    model_name = "flux-kontext"

    @property
    def description(self) -> str:
        """Human-readable summary of what this block does."""
        return (
            "Determines the height and width to be used during the subsequent computations.\n"
            "It should always be placed _before_ the latent preparation step."
        )

    @property
    def inputs(self) -> List[InputParam]:
        """Inputs read from the pipeline state by this block."""
        return [
            InputParam(name="height"),
            InputParam(name="width"),
            InputParam(name="max_area", type_hint=int, default=1024**2),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Values this block writes back to the pipeline state."""
        return [
            OutputParam(name="height", type_hint=int, description="The height of the initial noisy latents"),
            OutputParam(name="width", type_hint=int, description="The width of the initial noisy latents"),
        ]

    # Declared static: the function takes no `self`, yet `__call__` invokes it as
    # `self.check_inputs(height, width, scale)`. Without the decorator the instance
    # would be bound as `height` and the call would fail with a TypeError.
    @staticmethod
    def check_inputs(height, width, vae_scale_factor):
        """Ensure each provided dimension is divisible by `vae_scale_factor * 2`.

        Raises:
            ValueError: if `height` or `width` is not `None` and not divisible.
        """
        required_multiple = vae_scale_factor * 2
        if height is not None and height % required_multiple != 0:
            raise ValueError(f"Height must be divisible by {required_multiple} but is {height}")

        if width is not None and width % required_multiple != 0:
            raise ValueError(f"Width must be divisible by {required_multiple} but is {width}")

    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
        """Resolve the final `height`/`width` and store them in the state.

        Returns:
            The `(components, state)` pair with `height` and `width` updated.
        """
        block_state = self.get_block_state(state)

        # Fall back to the pipeline defaults when the caller gave no resolution.
        height = block_state.height or components.default_height
        width = block_state.width or components.default_width
        self.check_inputs(height, width, components.vae_scale_factor)

        original_height, original_width = height, width
        max_area = block_state.max_area

        # Solve width * height == max_area subject to width / height == aspect_ratio.
        aspect_ratio = width / height
        width = round((max_area * aspect_ratio) ** 0.5)
        height = round((max_area / aspect_ratio) ** 0.5)

        # Floor each side to the nearest multiple accepted by `check_inputs`.
        multiple_of = components.vae_scale_factor * 2
        width = width // multiple_of * multiple_of
        height = height // multiple_of * multiple_of

        if height != original_height or width != original_width:
            logger.warning(
                f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
            )

        block_state.height = height
        block_state.width = width

        self.set_block_state(state, block_state)
        return components, state