| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Any |
|
|
| import numpy as np |
| import PIL |
| import torch |
|
|
| from ...configuration_utils import FrozenDict |
| from ...image_processor import VaeImageProcessor |
| from ...models import AutoencoderKL |
| from ...utils import deprecate, logging |
| from ..modular_pipeline import ( |
| ModularPipelineBlocks, |
| PipelineState, |
| ) |
| from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
| class StableDiffusionXLDecodeStep(ModularPipelineBlocks): |
| model_name = "stable-diffusion-xl" |
|
|
| @property |
| def expected_components(self) -> list[ComponentSpec]: |
| return [ |
| ComponentSpec("vae", AutoencoderKL), |
| ComponentSpec( |
| "image_processor", |
| VaeImageProcessor, |
| config=FrozenDict({"vae_scale_factor": 8}), |
| default_creation_method="from_config", |
| ), |
| ] |
|
|
| @property |
| def description(self) -> str: |
| return "Step that decodes the denoised latents into images" |
|
|
| @property |
| def inputs(self) -> list[tuple[str, Any]]: |
| return [ |
| InputParam("output_type", default="pil"), |
| InputParam( |
| "latents", |
| required=True, |
| type_hint=torch.Tensor, |
| description="The denoised latents from the denoising step", |
| ), |
| ] |
|
|
| @property |
| def intermediate_outputs(self) -> list[str]: |
| return [ |
| OutputParam( |
| "images", |
| type_hint=list[PIL.Image.Image] | list[torch.Tensor] | list[np.array], |
| description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array", |
| ) |
| ] |
|
|
| @staticmethod |
| |
| def upcast_vae(components): |
| deprecate( |
| "upcast_vae", |
| "1.0.0", |
| "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", |
| ) |
| components.vae.to(dtype=torch.float32) |
|
|
| @torch.no_grad() |
| def __call__(self, components, state: PipelineState) -> PipelineState: |
| block_state = self.get_block_state(state) |
|
|
| if not block_state.output_type == "latent": |
| latents = block_state.latents |
| |
| block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast |
|
|
| if block_state.needs_upcasting: |
| self.upcast_vae(components) |
| latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype) |
| elif latents.dtype != components.vae.dtype: |
| if torch.backends.mps.is_available(): |
| |
| components.vae = components.vae.to(latents.dtype) |
|
|
| |
| |
| block_state.has_latents_mean = ( |
| hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None |
| ) |
| block_state.has_latents_std = ( |
| hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None |
| ) |
| if block_state.has_latents_mean and block_state.has_latents_std: |
| block_state.latents_mean = ( |
| torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) |
| ) |
| block_state.latents_std = ( |
| torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) |
| ) |
| latents = ( |
| latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean |
| ) |
| else: |
| latents = latents / components.vae.config.scaling_factor |
|
|
| block_state.images = components.vae.decode(latents, return_dict=False)[0] |
|
|
| |
| if block_state.needs_upcasting: |
| components.vae.to(dtype=torch.float16) |
| else: |
| block_state.images = block_state.latents |
|
|
| |
| if hasattr(components, "watermark") and components.watermark is not None: |
| block_state.images = components.watermark.apply_watermark(block_state.images) |
|
|
| block_state.images = components.image_processor.postprocess( |
| block_state.images, output_type=block_state.output_type |
| ) |
|
|
| self.set_block_state(state, block_state) |
|
|
| return components, state |
|
|
|
|
| class StableDiffusionXLInpaintOverlayMaskStep(ModularPipelineBlocks): |
| model_name = "stable-diffusion-xl" |
|
|
| @property |
| def description(self) -> str: |
| return ( |
| "A post-processing step that overlays the mask on the image (inpainting task only).\n" |
| + "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask" |
| ) |
|
|
| @property |
| def expected_components(self) -> list[ComponentSpec]: |
| return [ |
| ComponentSpec( |
| "image_processor", |
| VaeImageProcessor, |
| config=FrozenDict({"vae_scale_factor": 8}), |
| default_creation_method="from_config", |
| ), |
| ] |
|
|
| @property |
| def inputs(self) -> list[tuple[str, Any]]: |
| return [ |
| InputParam("image"), |
| InputParam("mask_image"), |
| InputParam("padding_mask_crop"), |
| InputParam( |
| "images", |
| type_hint=list[PIL.Image.Image] | list[torch.Tensor] | list[np.array], |
| description="The generated images from the decode step", |
| ), |
| InputParam( |
| "crops_coords", |
| type_hint=tuple[int, int], |
| description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.", |
| ), |
| ] |
|
|
| @torch.no_grad() |
| def __call__(self, components, state: PipelineState) -> PipelineState: |
| block_state = self.get_block_state(state) |
|
|
| if block_state.padding_mask_crop is not None and block_state.crops_coords is not None: |
| block_state.images = [ |
| components.image_processor.apply_overlay( |
| block_state.mask_image, block_state.image, i, block_state.crops_coords |
| ) |
| for i in block_state.images |
| ] |
|
|
| self.set_block_state(state, block_state) |
|
|
| return components, state |
|
|