| """Modular SDXL Upscale - consolidated Hub block. |
| |
| MultiDiffusion tiled upscaling for Stable Diffusion XL using Modular Diffusers. |
| """ |
|
|
|
|
| |
| |
| |
|
|
| """Tile planning and cosine blending weights for MultiDiffusion.""" |
|
|
| from dataclasses import dataclass |
|
|
| import torch |
|
|
|
|
@dataclass
class LatentTileSpec:
    """Tile specification in latent space.

    All coordinates/sizes are in latent pixels (one latent pixel maps to
    ``vae_scale_factor`` image pixels, typically 8 for SDXL).

    Attributes:
        y: Top edge in latent pixels.
        x: Left edge in latent pixels.
        h: Height in latent pixels.
        w: Width in latent pixels.
    """

    y: int
    x: int
    h: int
    w: int
|
|
|
|
def validate_tile_params(tile_size: int, overlap: int) -> None:
    """Validate MultiDiffusion tiling parameters.

    Args:
        tile_size: Tile edge length in latent pixels; must be positive.
        overlap: Overlap between adjacent tiles; must satisfy
            ``0 <= overlap < tile_size``.

    Raises:
        ValueError: If either constraint is violated.
    """
    if not tile_size > 0:
        raise ValueError(f"`tile_size` must be positive, got {tile_size}.")
    if not overlap >= 0:
        raise ValueError(f"`overlap` must be non-negative, got {overlap}.")
    if not overlap < tile_size:
        raise ValueError(
            f"`overlap` must be less than `tile_size`. "
            f"Got overlap={overlap}, tile_size={tile_size}."
        )
|
|
|
|
def plan_latent_tiles(
    latent_h: int,
    latent_w: int,
    tile_size: int = 64,
    overlap: int = 8,
) -> list[LatentTileSpec]:
    """Plan overlapping tiles in latent space for MultiDiffusion.

    Tiles overlap by ``overlap`` latent pixels. Edge tiles are clamped to
    the latent bounds, so the final tile in each axis ends exactly on the
    canvas edge. Tiles are emitted row-major (outer y, inner x).
    """
    validate_tile_params(tile_size, overlap)

    stride = tile_size - overlap

    def axis_starts(length: int) -> list[int]:
        # Walk one axis with the given stride. The last start position is
        # pulled back so its tile ends exactly at `length`; an axis shorter
        # than `tile_size` yields a single (shrunken) tile at 0.
        starts = [0]
        pos = 0
        while pos + min(tile_size, length - pos) < length:
            pos += stride
            if length - pos < tile_size:
                pos = max(0, length - tile_size)
            starts.append(pos)
        return starts

    return [
        LatentTileSpec(
            y=y,
            x=x,
            h=min(tile_size, latent_h - y),
            w=min(tile_size, latent_w - x),
        )
        for y in axis_starts(latent_h)
        for x in axis_starts(latent_w)
    ]
|
|
|
|
def make_cosine_tile_weight(
    h: int,
    w: int,
    overlap: int,
    device: torch.device,
    dtype: torch.dtype,
    is_top: bool = False,
    is_bottom: bool = False,
    is_left: bool = False,
    is_right: bool = False,
) -> torch.Tensor:
    """Boundary-aware cosine blending weight for one tile.

    Returns shape (1, 1, h, w). Canvas-edge sides get weight 1.0 (no fade),
    interior overlap regions get a half-cosine ramp from 0 to 1.
    """
    import math

    # Ramp never exceeds half the tile so rising/falling edges cannot overlap.
    ramp_len = min(overlap, h // 2, w // 2)
    if ramp_len <= 0:
        # No overlap (or degenerate tile): uniform weight everywhere.
        return torch.ones(1, 1, h, w, device=device, dtype=dtype)

    rising = torch.tensor(
        [0.5 * (1 - math.cos(math.pi * i / ramp_len)) for i in range(ramp_len)],
        device=device,
        dtype=dtype,
    )

    def axis_profile(size: int, fade_start: bool, fade_end: bool) -> torch.Tensor:
        # 1D weight profile: ones in the interior, cosine fades only on
        # sides that border another tile (not the canvas edge).
        profile = torch.ones(size, device=device, dtype=dtype)
        if fade_start:
            profile[:ramp_len] = rising
        if fade_end:
            profile[-ramp_len:] = rising.flip(0)
        return profile

    row_weight = axis_profile(h, not is_top, not is_bottom)
    col_weight = axis_profile(w, not is_left, not is_right)

    # Outer product -> (h, w), then add batch and channel dims.
    return (row_weight[:, None] * col_weight[None, :]).unsqueeze(0).unsqueeze(0)
|
|
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Input steps for Modular SDXL Upscale: text encoding, Lanczos upscale.""" |
|
|
| import PIL.Image |
| import torch |
|
|
| from diffusers.utils import logging |
| from diffusers.modular_pipelines.modular_pipeline import ModularPipelineBlocks, PipelineState |
| from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam |
| from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLTextEncoderStep |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
class UltimateSDUpscaleTextEncoderStep(StableDiffusionXLTextEncoderStep):
    """SDXL text encoder step that applies guidance scale before encoding.

    Syncs the guider's guidance_scale before prompt encoding so that
    unconditional embeddings are always produced when CFG is active.

    Also applies a default negative prompt for upscaling when the user
    does not provide one.
    """

    DEFAULT_NEGATIVE_PROMPT = "blurry, low quality, artifacts, noise, jpeg compression"

    @property
    def inputs(self) -> list[InputParam]:
        extra_params = [
            InputParam(
                "guidance_scale",
                type_hint=float,
                default=7.5,
                description="Classifier-Free Guidance scale.",
            ),
            InputParam(
                "use_default_negative",
                type_hint=bool,
                default=True,
                description="Apply default negative prompt when none is provided.",
            ),
        ]
        return super().inputs + extra_params

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        # Push the requested CFG scale onto the guider *before* encoding so
        # unconditional embeddings are produced whenever CFG is active.
        cfg_scale = getattr(block_state, "guidance_scale", 7.5)
        guider = getattr(components, "guider", None)
        if guider is not None:
            guider.guidance_scale = cfg_scale

        # Fall back to an upscaling-oriented negative prompt when the caller
        # opted in (the default) and supplied none.
        if getattr(block_state, "use_default_negative", True):
            current_negative = getattr(block_state, "negative_prompt", None)
            if current_negative is None or current_negative == "":
                block_state.negative_prompt = self.DEFAULT_NEGATIVE_PROMPT
                state.set("negative_prompt", self.DEFAULT_NEGATIVE_PROMPT)

        # Delegate the actual encoding to the stock SDXL text encoder step.
        return super().__call__(components, state)
|
|
|
|
class UltimateSDUpscaleUpscaleStep(ModularPipelineBlocks):
    """Upscales the input image using Lanczos interpolation."""

    @property
    def description(self) -> str:
        return "Upscale input image using Lanczos interpolation."

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("image", type_hint=PIL.Image.Image, required=True,
                       description="Input image to upscale."),
            InputParam("upscale_factor", type_hint=float, default=2.0,
                       description="Scale multiplier."),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("upscaled_image", type_hint=PIL.Image.Image),
            OutputParam("upscaled_width", type_hint=int),
            OutputParam("upscaled_height", type_hint=int),
        ]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        """Resize ``image`` by ``upscale_factor`` with Lanczos resampling.

        Writes ``upscaled_image``, ``upscaled_width`` and ``upscaled_height``
        into the block state.

        Raises:
            ValueError: If ``image`` is not a PIL image, if
                ``upscale_factor`` is not positive, or if the resulting
                dimensions would be zero.
        """
        block_state = self.get_block_state(state)

        image = block_state.image
        upscale_factor = block_state.upscale_factor

        if not isinstance(image, PIL.Image.Image):
            raise ValueError(f"Expected PIL.Image, got {type(image)}.")
        # Guard against non-positive factors, which would otherwise surface
        # as an opaque failure inside PIL's resize.
        if upscale_factor <= 0:
            raise ValueError(f"`upscale_factor` must be positive, got {upscale_factor}.")

        # int() truncates; PIL requires integral sizes.
        new_width = int(image.width * upscale_factor)
        new_height = int(image.height * upscale_factor)
        if new_width <= 0 or new_height <= 0:
            raise ValueError(
                f"`upscale_factor`={upscale_factor} yields an empty image "
                f"({new_width}x{new_height}) for input {image.width}x{image.height}."
            )

        block_state.upscaled_image = image.resize((new_width, new_height), PIL.Image.LANCZOS)
        block_state.upscaled_width = new_width
        block_state.upscaled_height = new_height

        logger.info(f"Upscaled {image.width}x{image.height} -> {new_width}x{new_height}")

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """MultiDiffusion tiled upscaling step for Modular SDXL Upscale. |
| |
| Blends noise predictions from overlapping latent tiles using cosine weights. |
| Reuses SDXL blocks via their public interface. |
| """ |
|
|
| import math |
| import time |
|
|
| import numpy as np |
| import PIL.Image |
| import torch |
| from tqdm.auto import tqdm |
|
|
| from diffusers.configuration_utils import FrozenDict |
| from diffusers.guiders import ClassifierFreeGuidance |
| from diffusers.image_processor import VaeImageProcessor |
| from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel |
| from diffusers.schedulers import DPMSolverMultistepScheduler, EulerDiscreteScheduler |
| from diffusers.utils import logging |
| from diffusers.utils.torch_utils import randn_tensor |
| from diffusers.modular_pipelines.modular_pipeline import ( |
| ModularPipelineBlocks, |
| PipelineState, |
| ) |
| from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam |
| from diffusers.modular_pipelines.stable_diffusion_xl.before_denoise import ( |
| StableDiffusionXLControlNetInputStep, |
| StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, |
| StableDiffusionXLImg2ImgPrepareLatentsStep, |
| prepare_latents_img2img, |
| ) |
| from diffusers.modular_pipelines.stable_diffusion_xl.decoders import StableDiffusionXLDecodeStep |
| from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLVaeEncoderStep |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
| |
| |
| |
|
|
def _make_state(values: dict, kwargs_type_map: dict | None = None) -> PipelineState:
    """Create a PipelineState and set values, optionally with kwargs_type."""
    state = PipelineState()
    type_map = {} if kwargs_type_map is None else kwargs_type_map
    for name, value in values.items():
        # `kwargs_type` is looked up per key; missing entries pass None.
        state.set(name, value, type_map.get(name))
    return state
|
|
|
|
def _to_pil_rgb_image(image) -> PIL.Image.Image:
    """Convert a tensor/ndarray/PIL image to a RGB PIL image.

    Accepts PIL images (any mode), torch tensors, or numpy arrays in
    (H, W), (H, W, C) or (C, H, W) layout, optionally with a leading
    batch dimension of 1. Non-uint8 arrays with values <= 1.0 are treated
    as [0, 1] floats and rescaled to uint8.

    Raises:
        ValueError: For unsupported types, batch sizes > 1, wrong rank,
            or unsupported channel counts.
    """
    if isinstance(image, PIL.Image.Image):
        return image.convert("RGB")

    # Tensors are normalized into a numpy array and handled below.
    if torch.is_tensor(image):
        tensor = image.detach().cpu()
        if tensor.ndim == 4:
            if tensor.shape[0] != 1:
                raise ValueError(
                    f"`control_image` tensor batch must be 1 for tiled upscaling, got shape {tuple(tensor.shape)}."
                )
            tensor = tensor[0]
        # Channels-first -> channels-last when the layout is unambiguous.
        if tensor.ndim == 3 and tensor.shape[0] in (1, 3, 4) and tensor.shape[-1] not in (1, 3, 4):
            tensor = tensor.permute(1, 2, 0)
        image = tensor.numpy()

    if not isinstance(image, np.ndarray):
        raise ValueError(
            f"Unsupported `control_image` type {type(image)}. Expected PIL.Image, torch.Tensor, or numpy.ndarray."
        )

    arr = image
    if arr.ndim == 4:
        if arr.shape[0] != 1:
            raise ValueError(
                f"`control_image` ndarray batch must be 1 for tiled upscaling, got shape {arr.shape}."
            )
        arr = arr[0]
    if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
        arr = np.transpose(arr, (1, 2, 0))
    if arr.ndim == 2:
        # Grayscale -> replicate into three channels.
        arr = np.stack([arr] * 3, axis=-1)
    if arr.ndim != 3:
        raise ValueError(f"`control_image` must have 2 or 3 dimensions, got shape {arr.shape}.")
    if arr.shape[-1] == 1:
        arr = np.repeat(arr, 3, axis=-1)
    if arr.shape[-1] == 4:
        # Drop the alpha channel.
        arr = arr[..., :3]
    if arr.shape[-1] != 3:
        raise ValueError(f"`control_image` channel dimension must be 1/3/4, got shape {arr.shape}.")
    if arr.dtype != np.uint8:
        arr = np.asarray(arr, dtype=np.float32)
        # Heuristic: peak <= 1.0 means a [0, 1] float image.
        peak = float(np.max(arr)) if arr.size > 0 else 1.0
        if peak <= 1.0:
            arr = (np.clip(arr, 0.0, 1.0) * 255.0).astype(np.uint8)
        else:
            arr = np.clip(arr, 0.0, 255.0).astype(np.uint8)
    return PIL.Image.fromarray(arr).convert("RGB")
|
|
|
|
| |
| |
| |
|
|
# Case-insensitive aliases mapping human-readable scheduler names to the
# diffusers scheduler class to instantiate. A "+karras" suffix marks a
# variant that enables Karras sigmas (handled in `_swap_scheduler`).
_SCHEDULER_ALIASES = {
    "euler": "EulerDiscreteScheduler",
    "euler discrete": "EulerDiscreteScheduler",
    "eulerdiscretescheduler": "EulerDiscreteScheduler",
    "dpm++ 2m": "DPMSolverMultistepScheduler",
    "dpmsolvermultistepscheduler": "DPMSolverMultistepScheduler",
    "dpm++ 2m karras": "DPMSolverMultistepScheduler+karras",
}
|
|
|
|
def _swap_scheduler(components, scheduler_name: str):
    """Swap the scheduler on ``components`` given a human-readable name.

    Supported names (case-insensitive):
    - ``"Euler"`` / ``"EulerDiscreteScheduler"``
    - ``"DPM++ 2M"`` / ``"DPMSolverMultistepScheduler"``
    - ``"DPM++ 2M Karras"`` (DPMSolverMultistep with Karras sigmas)

    If the requested scheduler is already active, this is a no-op.
    """
    normalized = scheduler_name.strip().lower()
    target = _SCHEDULER_ALIASES.get(normalized, normalized)

    wants_karras = target.endswith("+karras")
    if wants_karras:
        target = target.replace("+karras", "")

    active = type(components.scheduler).__name__

    if target == "EulerDiscreteScheduler":
        if active != "EulerDiscreteScheduler":
            # Rebuild from the current config so hyperparameters carry over.
            components.scheduler = EulerDiscreteScheduler.from_config(components.scheduler.config)
            logger.info("Swapped scheduler to EulerDiscreteScheduler")
        return

    if target == "DPMSolverMultistepScheduler":
        # Also swap when the class matches but Karras sigmas were requested
        # and are not yet enabled on the active config.
        karras_missing = wants_karras and not getattr(
            components.scheduler.config, "use_karras_sigmas", False
        )
        if active != "DPMSolverMultistepScheduler" or karras_missing:
            overrides = {"use_karras_sigmas": True} if wants_karras else {}
            components.scheduler = DPMSolverMultistepScheduler.from_config(
                components.scheduler.config, **overrides
            )
            logger.info(f"Swapped scheduler to DPMSolverMultistepScheduler (karras={wants_karras})")
        return

    logger.warning(
        f"Unknown scheduler_name '{scheduler_name}'. Keeping current scheduler "
        f"({active}). Supported: 'Euler', 'DPM++ 2M', 'DPM++ 2M Karras'."
    )
|
|
|
|
| |
| |
| |
|
|
def _compute_auto_strength(upscale_factor: float, pass_index: int, num_passes: int) -> float:
    """Return the auto-scaled denoise strength for a given pass.

    Rules:
    - Progressive passes: first pass=0.3, subsequent passes=0.2
    - Single-pass <= 2x: 0.3
    - Single-pass <= 4x: 0.15
    - Single-pass > 4x: 0.1
    """
    if num_passes > 1:
        # Progressive mode: denoise harder on the first pass only.
        return 0.3 if pass_index == 0 else 0.2
    if upscale_factor <= 2.0:
        return 0.3
    if upscale_factor <= 4.0:
        return 0.15
    return 0.1
|
|
| class UltimateSDUpscaleMultiDiffusionStep(ModularPipelineBlocks): |
| """Single block that encodes, denoises with MultiDiffusion, and decodes. |
| |
| MultiDiffusion inverts the standard tile loop: the **outer** loop iterates |
| over timesteps and the **inner** loop iterates over overlapping latent |
| tiles. At each timestep, per-tile noise predictions are blended with |
| cosine-ramp overlap weights, then a single scheduler step is applied to |
| the full latent tensor. This eliminates tile-boundary artifacts because |
| blending happens in noise-prediction space, not pixel space. |
| |
| The full flow: |
| 1. Enable VAE tiling for memory-efficient encode/decode. |
| 2. VAE-encode the upscaled image to full-resolution latents. |
| 3. Add noise at the strength-determined level. |
| 4. For each timestep: |
| a. Plan overlapping tiles in latent space. |
| b. For each tile: crop latents, run UNet (+ optional ControlNet) |
| through the guider for CFG, accumulate weighted noise predictions. |
| c. Normalize predictions by accumulated weights. |
| d. One ``scheduler.step`` on the full blended prediction. |
| 5. VAE-decode the final latents. |
| """ |
|
|
| model_name = "stable-diffusion-xl" |
|
|
    def __init__(self):
        super().__init__()
        # Reuse the stock SDXL modular blocks for encode / prepare / decode;
        # only the tiled denoising loop itself is custom to this block.
        self._vae_encoder = StableDiffusionXLVaeEncoderStep()
        self._prepare_latents = StableDiffusionXLImg2ImgPrepareLatentsStep()
        self._prepare_add_cond = StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep()
        self._prepare_controlnet = StableDiffusionXLControlNetInputStep()
        self._decode = StableDiffusionXLDecodeStep()
|
|
    @property
    def description(self) -> str:
        """Return a short human-readable description of this block."""
        return (
            "MultiDiffusion tiled denoising: encodes the full upscaled image, "
            "denoises with latent-space noise-prediction blending across "
            "overlapping tiles, then decodes. Produces seamless output at any "
            "resolution without tile-boundary artifacts."
        )
|
|
    @property
    def expected_components(self) -> list[ComponentSpec]:
        """Components this block loads from the pipeline.

        `controlnet` is only exercised when a `control_image` input is
        provided; the processors and the guider can be built from config
        (``default_creation_method="from_config"``) when not supplied.
        """
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 8}),
                default_creation_method="from_config",
            ),
            ComponentSpec(
                "control_image_processor",
                VaeImageProcessor,
                # ControlNet conditioning stays in [0, 1]; no normalization.
                config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}),
                default_creation_method="from_config",
            ),
            ComponentSpec("scheduler", EulerDiscreteScheduler),
            ComponentSpec("unet", UNet2DConditionModel),
            ComponentSpec("controlnet", ControlNetModel),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 7.5}),
                default_creation_method="from_config",
            ),
        ]
|
|
    @property
    def expected_configs(self) -> list[ConfigSpec]:
        # Base-SDXL conditioning path: aesthetic-score time ids not required.
        return [ConfigSpec("requires_aesthetics_score", False)]
|
|
    @property
    def inputs(self) -> list[InputParam]:
        """Declare every value this block reads from pipeline state."""
        return [
            # Outputs of the Lanczos upscale step.
            InputParam("upscaled_image", type_hint=PIL.Image.Image, required=True),
            InputParam("upscaled_height", type_hint=int, required=True),
            InputParam("upscaled_width", type_hint=int, required=True),
            InputParam("image", type_hint=PIL.Image.Image,
                       description="Original input image (before upscaling). Needed for progressive mode."),
            InputParam("upscale_factor", type_hint=float, default=2.0,
                       description="Total upscale factor. Used for auto-strength and progressive upscaling."),
            InputParam("generator"),
            InputParam("batch_size", type_hint=int, required=True),
            InputParam("num_images_per_prompt", type_hint=int, default=1),
            InputParam("dtype", type_hint=torch.dtype, required=True),
            InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, kwargs_type="denoiser_input_fields"),
            InputParam("num_inference_steps", type_hint=int, default=50),
            InputParam("strength", type_hint=float, default=0.3),
            InputParam("timesteps", type_hint=torch.Tensor, required=True),
            InputParam("latent_timestep", type_hint=torch.Tensor, required=True),
            InputParam("denoising_start"),
            InputParam("denoising_end"),
            InputParam("output_type", type_hint=str, default="pil"),

            # Text-encoder outputs consumed by the guider/UNet.
            InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, kwargs_type="denoiser_input_fields"),
            InputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
            InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
            InputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
            InputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
            InputParam("eta", type_hint=float, default=0.0),

            InputParam("guidance_scale", type_hint=float, default=7.5,
                       description="Classifier-Free Guidance scale. Higher values produce images more aligned "
                       "with the prompt at the expense of lower image quality."),

            # Tiling geometry, expressed in latent pixels (1 latent px = 8 image px).
            InputParam("latent_tile_size", type_hint=int, default=64,
                       description="Tile size in latent pixels (64 = 512px). For single pass, set >= latent dims."),
            InputParam("latent_overlap", type_hint=int, default=16,
                       description="Overlap in latent pixels (16 = 128px)."),

            # Optional ControlNet conditioning.
            InputParam("control_image",
                       description="Optional ControlNet conditioning image."),
            InputParam("control_guidance_start", default=0.0),
            InputParam("control_guidance_end", default=1.0),
            InputParam("controlnet_conditioning_scale", default=1.0),
            InputParam("guess_mode", default=False),

            InputParam("progressive", type_hint=bool, default=True,
                       description="When True and upscale_factor > 2, split into multiple 2x passes "
                       "instead of one big jump. E.g. 4x = 2x then 2x."),

            InputParam("auto_strength", type_hint=bool, default=True,
                       description="When True and user does not explicitly pass strength, automatically "
                       "scale denoise strength based on upscale factor and pass index."),

            InputParam("return_metadata", type_hint=bool, default=False,
                       description="When True, include generation metadata (sizes, passes, timings) "
                       "in the output."),

            InputParam("scheduler_name", type_hint=str, default=None,
                       description="Optional scheduler name to swap before running. "
                       "Supported: 'Euler', 'DPM++ 2M', 'DPM++ 2M Karras'."),
        ]
|
|
    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        """Declare the values this block writes back to pipeline state."""
        return [
            OutputParam("images", type_hint=list, description="Final upscaled output images."),
            # Only populated when the `return_metadata` input is True.
            OutputParam("metadata", type_hint=dict,
                        description="Generation metadata (when return_metadata=True)."),
        ]
|
|
    def _run_tile_unet(
        self,
        components,
        tile_latents: torch.Tensor,
        t: int,
        i: int,
        block_state,
        controlnet_cond_tile=None,
    ) -> torch.Tensor:
        """Run guider + UNet (+ optional ControlNet) on one tile, return noise_pred.

        Args:
            components: Pipeline components (scheduler, unet, guider, controlnet).
            tile_latents: Latents cropped to one tile.
            t: Current scheduler timestep value.
            i: Index of the current denoising step (0-based).
            block_state: Block state carrying prompt embeddings, time ids and
                cached ControlNet settings (``_cn_*`` attributes set by
                ``_run_single_pass``).
            controlnet_cond_tile: Optional ControlNet conditioning crop matching
                this tile in pixel space.

        Returns:
            The guider-combined noise prediction for this tile.
        """
        # Some schedulers (e.g. Euler) require per-timestep input scaling.
        scaled_latents = components.scheduler.scale_model_input(tile_latents, t)

        pos_prompt = getattr(block_state, "prompt_embeds", None)
        neg_prompt = getattr(block_state, "negative_prompt_embeds", None)
        pos_pooled = getattr(block_state, "pooled_prompt_embeds", None)
        neg_pooled = getattr(block_state, "negative_pooled_prompt_embeds", None)
        pos_time_ids = getattr(block_state, "add_time_ids", None)
        neg_time_ids = getattr(block_state, "negative_add_time_ids", None)

        # Synthesize zero/cloned negatives so guidance can still run when the
        # text encoder produced no negative embeddings.
        if neg_prompt is None and pos_prompt is not None:
            neg_prompt = torch.zeros_like(pos_prompt)
        if neg_pooled is None and pos_pooled is not None:
            neg_pooled = torch.zeros_like(pos_pooled)
        if neg_time_ids is None and pos_time_ids is not None:
            neg_time_ids = pos_time_ids.clone()

        # (conditional, unconditional) pairs in the layout the guider expects.
        guider_inputs = {
            "prompt_embeds": (pos_prompt, neg_prompt),
            "time_ids": (pos_time_ids, neg_time_ids),
            "text_embeds": (pos_pooled, neg_pooled),
        }

        components.guider.set_state(
            step=i,
            num_inference_steps=block_state.num_inference_steps,
            timestep=t,
        )
        guider_state = components.guider.prepare_inputs(guider_inputs)

        # One iteration per guidance branch prepared by the guider.
        for guider_state_batch in guider_state:
            components.guider.prepare_models(components.unet)

            added_cond_kwargs = {
                "text_embeds": guider_state_batch.text_embeds,
                "time_ids": guider_state_batch.time_ids,
            }

            down_block_res_samples = None
            mid_block_res_sample = None

            if (
                controlnet_cond_tile is not None
                and components.controlnet is not None
                and guider_state_batch.text_embeds is not None
            ):
                cn_added_cond = {
                    "text_embeds": guider_state_batch.text_embeds,
                    "time_ids": guider_state_batch.time_ids,
                }
                # Scale conditioning by the per-step keep schedule cached by
                # `_run_single_pass` (control_guidance_start/end windowing).
                cond_scale = block_state._cn_cond_scale
                if isinstance(block_state._cn_controlnet_keep, list) and i < len(block_state._cn_controlnet_keep):
                    keep_val = block_state._cn_controlnet_keep[i]
                else:
                    keep_val = 1.0
                if isinstance(cond_scale, list):
                    cond_scale = [c * keep_val for c in cond_scale]
                else:
                    cond_scale = cond_scale * keep_val

                guess_mode = getattr(block_state, "guess_mode", False)
                if guess_mode and not components.guider.is_conditional:
                    # Guess mode: the unconditional branch gets zero residuals,
                    # using shapes cached from a previous conditional call
                    # (None on the very first step — TODO confirm intent).
                    down_block_res_samples = [torch.zeros_like(s) for s in block_state._cn_zeros_down] if hasattr(block_state, "_cn_zeros_down") else None
                    mid_block_res_sample = torch.zeros_like(block_state._cn_zeros_mid) if hasattr(block_state, "_cn_zeros_mid") else None
                else:
                    down_block_res_samples, mid_block_res_sample = components.controlnet(
                        scaled_latents,
                        t,
                        encoder_hidden_states=guider_state_batch.prompt_embeds,
                        controlnet_cond=controlnet_cond_tile,
                        conditioning_scale=cond_scale,
                        guess_mode=guess_mode,
                        added_cond_kwargs=cn_added_cond,
                        return_dict=False,
                    )
                    # Cache zero templates for later guess-mode uncond batches.
                    if not hasattr(block_state, "_cn_zeros_down"):
                        block_state._cn_zeros_down = [torch.zeros_like(d) for d in down_block_res_samples]
                        block_state._cn_zeros_mid = torch.zeros_like(mid_block_res_sample)

            unet_kwargs = {
                "sample": scaled_latents,
                "timestep": t,
                "encoder_hidden_states": guider_state_batch.prompt_embeds,
                "added_cond_kwargs": added_cond_kwargs,
                "return_dict": False,
            }
            if down_block_res_samples is not None:
                unet_kwargs["down_block_additional_residuals"] = down_block_res_samples
                unet_kwargs["mid_block_additional_residual"] = mid_block_res_sample

            guider_state_batch.noise_pred = components.unet(**unet_kwargs)[0]
            components.guider.cleanup_models(components.unet)

        # Combine branch predictions into the final (e.g. CFG) noise estimate.
        noise_pred = components.guider(guider_state)[0]
        return noise_pred
|
|
| def _run_single_pass( |
| self, |
| components, |
| block_state, |
| upscaled_image: PIL.Image.Image, |
| h: int, |
| w: int, |
| ctrl_pil, |
| use_controlnet: bool, |
| latent_tile_size: int, |
| latent_overlap: int, |
| ) -> np.ndarray: |
| """Run one MultiDiffusion encode-denoise-decode pass, return decoded numpy (h, w, 3).""" |
| from diffusers.modular_pipelines.stable_diffusion_xl.before_denoise import retrieve_timesteps |
|
|
| |
| if hasattr(components.vae, "enable_tiling"): |
| components.vae.enable_tiling() |
|
|
| |
| full_controlnet_cond = None |
| if use_controlnet and ctrl_pil is not None: |
| if ctrl_pil.size != (w, h): |
| ctrl_pil = ctrl_pil.resize((w, h), PIL.Image.LANCZOS) |
|
|
| |
| enc_state = _make_state({ |
| "image": upscaled_image, |
| "height": h, |
| "width": w, |
| "generator": block_state.generator, |
| "dtype": block_state.dtype, |
| "preprocess_kwargs": None, |
| }) |
| components, enc_state = self._vae_encoder(components, enc_state) |
| image_latents = enc_state.get("image_latents") |
|
|
| |
| |
| |
| |
| |
| |
| pass_strength = block_state._current_pass_strength |
| truncated_steps = block_state.num_inference_steps |
| outer_strength = block_state.strength |
| if outer_strength > 0: |
| original_steps = max(1, round(truncated_steps / outer_strength)) |
| else: |
| original_steps = max(1, truncated_steps) |
| _ts, _nsteps = retrieve_timesteps( |
| components.scheduler, |
| original_steps, |
| components._execution_device, |
| ) |
| from diffusers.modular_pipelines.stable_diffusion_xl.before_denoise import StableDiffusionXLImg2ImgSetTimestepsStep |
| timesteps, num_inf_steps = StableDiffusionXLImg2ImgSetTimestepsStep.get_timesteps( |
| components, original_steps, pass_strength, components._execution_device, |
| ) |
| latent_timestep = timesteps[:1].repeat(block_state.batch_size * block_state.num_images_per_prompt) |
|
|
| |
| lat_state = _make_state({ |
| "image_latents": image_latents, |
| "latent_timestep": latent_timestep, |
| "batch_size": block_state.batch_size, |
| "num_images_per_prompt": block_state.num_images_per_prompt, |
| "dtype": block_state.dtype, |
| "generator": block_state.generator, |
| "latents": None, |
| "denoising_start": getattr(block_state, "denoising_start", None), |
| }) |
| components, lat_state = self._prepare_latents(components, lat_state) |
| latents = lat_state.get("latents") |
|
|
| |
| if use_controlnet and ctrl_pil is not None: |
| ctrl_state = _make_state({ |
| "control_image": ctrl_pil, |
| "control_guidance_start": getattr(block_state, "control_guidance_start", 0.0), |
| "control_guidance_end": getattr(block_state, "control_guidance_end", 1.0), |
| "controlnet_conditioning_scale": getattr(block_state, "controlnet_conditioning_scale", 1.0), |
| "guess_mode": getattr(block_state, "guess_mode", False), |
| "num_images_per_prompt": block_state.num_images_per_prompt, |
| "latents": latents, |
| "batch_size": block_state.batch_size, |
| "timesteps": timesteps, |
| "crops_coords": None, |
| }) |
| components, ctrl_state = self._prepare_controlnet(components, ctrl_state) |
| full_controlnet_cond = ctrl_state.get("controlnet_cond") |
| block_state._cn_cond_scale = ctrl_state.get("conditioning_scale") |
| block_state._cn_controlnet_keep = ctrl_state.get("controlnet_keep") |
| block_state.guess_mode = ctrl_state.get("guess_mode") |
|
|
| |
| cond_state = _make_state({ |
| "original_size": (h, w), |
| "target_size": (h, w), |
| "crops_coords_top_left": (0, 0), |
| "negative_original_size": None, |
| "negative_target_size": None, |
| "negative_crops_coords_top_left": (0, 0), |
| "num_images_per_prompt": block_state.num_images_per_prompt, |
| "aesthetic_score": 6.0, |
| "negative_aesthetic_score": 2.0, |
| "latents": latents, |
| "pooled_prompt_embeds": block_state.pooled_prompt_embeds, |
| "batch_size": block_state.batch_size, |
| }) |
| components, cond_state = self._prepare_add_cond(components, cond_state) |
| block_state.add_time_ids = cond_state.get("add_time_ids") |
| block_state.negative_add_time_ids = cond_state.get("negative_add_time_ids") |
|
|
| |
| latent_h, latent_w = latents.shape[-2], latents.shape[-1] |
| tile_specs = plan_latent_tiles(latent_h, latent_w, latent_tile_size, latent_overlap) |
| num_tiles = len(tile_specs) |
| logger.info( |
| f"MultiDiffusion: {num_tiles} latent tiles " |
| f"({latent_h}x{latent_w}, tile={latent_tile_size}, overlap={latent_overlap})" |
| ) |
|
|
| |
| guidance_scale = getattr(block_state, "guidance_scale", 7.5) |
| components.guider.guidance_scale = guidance_scale |
| disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False |
| if disable_guidance: |
| components.guider.disable() |
| else: |
| components.guider.enable() |
|
|
| |
| block_state.num_inference_steps = num_inf_steps |
|
|
| |
| vae_scale_factor = int(getattr(components, "vae_scale_factor", 8)) |
| progress_kwargs = getattr(components, "_progress_bar_config", {}) |
| if not isinstance(progress_kwargs, dict): |
| progress_kwargs = {} |
|
|
| for i, t in enumerate( |
| tqdm(timesteps, total=num_inf_steps, desc="MultiDiffusion", **progress_kwargs) |
| ): |
| noise_pred_accum = torch.zeros_like(latents, dtype=torch.float32) |
| weight_accum = torch.zeros( |
| 1, 1, latent_h, latent_w, |
| device=latents.device, dtype=torch.float32, |
| ) |
|
|
| for tile in tile_specs: |
| tile_latents = latents[:, :, tile.y:tile.y + tile.h, tile.x:tile.x + tile.w].clone() |
|
|
| cn_tile = None |
| if use_controlnet and full_controlnet_cond is not None: |
| py = tile.y * vae_scale_factor |
| px = tile.x * vae_scale_factor |
| ph = tile.h * vae_scale_factor |
| pw = tile.w * vae_scale_factor |
| cn_tile = full_controlnet_cond[:, :, py:py + ph, px:px + pw] |
|
|
| tile_noise_pred = self._run_tile_unet( |
| components, tile_latents, t, i, block_state, cn_tile, |
| ) |
|
|
| tile_weight = _make_cosine_tile_weight( |
| tile.h, tile.w, latent_overlap, |
| latents.device, torch.float32, |
| is_top=(tile.y == 0), |
| is_bottom=(tile.y + tile.h >= latent_h), |
| is_left=(tile.x == 0), |
| is_right=(tile.x + tile.w >= latent_w), |
| ) |
|
|
| noise_pred_accum[:, :, tile.y:tile.y + tile.h, tile.x:tile.x + tile.w] += ( |
| tile_noise_pred.to(torch.float32) * tile_weight |
| ) |
| weight_accum[:, :, tile.y:tile.y + tile.h, tile.x:tile.x + tile.w] += tile_weight |
|
|
| blended_noise_pred = noise_pred_accum / weight_accum.clamp(min=1e-6) |
| blended_noise_pred = torch.nan_to_num(blended_noise_pred, nan=0.0, posinf=0.0, neginf=0.0) |
| blended_noise_pred = blended_noise_pred.to(latents.dtype) |
|
|
| latents_dtype = latents.dtype |
| latents = components.scheduler.step( |
| blended_noise_pred, t, latents, |
| return_dict=False, |
| )[0] |
| if latents.dtype != latents_dtype and torch.backends.mps.is_available(): |
| latents = latents.to(latents_dtype) |
|
|
| |
| decode_state = _make_state({ |
| "latents": latents, |
| "output_type": "np", |
| }) |
| components, decode_state = self._decode(components, decode_state) |
| decoded_images = decode_state.get("images") |
| decoded_np = decoded_images[0] |
|
|
| if decoded_np.shape[0] != h or decoded_np.shape[1] != w: |
| pil_out = PIL.Image.fromarray((np.clip(decoded_np, 0, 1) * 255).astype(np.uint8)) |
| pil_out = pil_out.resize((w, h), PIL.Image.LANCZOS) |
| decoded_np = np.array(pil_out).astype(np.float32) / 255.0 |
|
|
| return decoded_np |
|
|
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
    """Run tiled MultiDiffusion upscaling, optionally as progressive ~2x passes.

    Reads from ``block_state``: ``output_type``, ``latent_tile_size``,
    ``latent_overlap``, ``strength``, ``upscaled_image`` / ``upscaled_width`` /
    ``upscaled_height``, plus optional ``scheduler_name``, ``upscale_factor``,
    ``progressive``, ``auto_strength``, ``return_metadata``, ``image``,
    ``control_image``.

    Writes ``block_state.images`` (single-element list in the requested
    ``output_type``) and ``block_state.metadata``, then folds ``block_state``
    back into ``state``.

    Note: despite the annotation, this returns ``(components, state)``,
    matching how the surrounding block framework invokes it.

    Raises:
        ValueError: if ``control_image`` is a list, or if it is provided
            without a loaded ``controlnet`` component.
    """
    block_state = self.get_block_state(state)
    t_start = time.time()  # wall-clock for the metadata "total_time" field

    output_type = block_state.output_type
    latent_tile_size = block_state.latent_tile_size
    latent_overlap = block_state.latent_overlap

    # Optional scheduler swap before denoising (no-op when unset).
    scheduler_name = getattr(block_state, "scheduler_name", None)
    if scheduler_name is not None:
        _swap_scheduler(components, scheduler_name)

    # Optional knobs with defaults for callers that don't set them.
    upscale_factor = getattr(block_state, "upscale_factor", 2.0)
    progressive = getattr(block_state, "progressive", True)
    auto_strength = getattr(block_state, "auto_strength", True)
    return_metadata = getattr(block_state, "return_metadata", False)

    user_strength = block_state.strength

    # Progressive mode splits a >2x upscale into ceil(log2(factor)) passes,
    # each roughly doubling resolution; otherwise one pass at full factor.
    if progressive and upscale_factor > 2.0:
        num_passes = max(1, int(math.ceil(math.log2(upscale_factor))))
    else:
        num_passes = 1

    # Per-pass denoise strength: either auto-scheduled or the user's value
    # repeated for every pass.
    strength_per_pass = []
    for p in range(num_passes):
        if auto_strength:
            strength_per_pass.append(
                _compute_auto_strength(upscale_factor, p, num_passes)
            )
        else:
            strength_per_pass.append(user_strength)

    original_image = getattr(block_state, "image", None)
    input_w = block_state.upscaled_width
    input_h = block_state.upscaled_height

    if original_image is not None:
        orig_input_size = (original_image.width, original_image.height)
    else:
        # No original supplied: infer the pre-upscale size by dividing the
        # target size back down by the upscale factor (for metadata and as
        # the progressive-mode starting resolution).
        orig_input_size = (
            int(round(input_w / upscale_factor)),
            int(round(input_h / upscale_factor)),
        )

    # ControlNet is opt-in via `control_image`; validate eagerly so failures
    # happen before any expensive denoising work.
    control_image_raw = getattr(block_state, "control_image", None)
    use_controlnet = False
    if control_image_raw is not None:
        if isinstance(control_image_raw, list):
            raise ValueError(
                "MultiDiffusion currently supports a single `control_image`, not a list."
            )
        if not hasattr(components, "controlnet") or components.controlnet is None:
            raise ValueError(
                "`control_image` was provided but `controlnet` component is missing. "
                "Load a ControlNet model into `pipe.controlnet` first."
            )
        use_controlnet = True
        logger.info("MultiDiffusion: ControlNet enabled.")

    if num_passes == 1:
        # Single pass at the final target resolution.
        block_state._current_pass_strength = strength_per_pass[0]

        ctrl_pil = None
        if use_controlnet:
            # Conform the user control image to the target size.
            ctrl_pil = _to_pil_rgb_image(control_image_raw)
            h, w = block_state.upscaled_height, block_state.upscaled_width
            if ctrl_pil.size != (w, h):
                ctrl_pil = ctrl_pil.resize((w, h), PIL.Image.LANCZOS)

        decoded_np = self._run_single_pass(
            components, block_state,
            upscaled_image=block_state.upscaled_image,
            h=block_state.upscaled_height,
            w=block_state.upscaled_width,
            ctrl_pil=ctrl_pil,
            use_controlnet=use_controlnet,
            latent_tile_size=latent_tile_size,
            latent_overlap=latent_overlap,
        )
    else:
        # Progressive mode starts from the ORIGINAL (small) image so each
        # pass only has to bridge ~2x. If it wasn't provided, reconstruct it
        # by downscaling the pre-upscaled image.
        if original_image is None:
            original_image = block_state.upscaled_image.resize(
                orig_input_size, PIL.Image.LANCZOS,
            )

        current_image = original_image
        per_pass_factor = 2.0
        current_w, current_h = current_image.width, current_image.height

        for p in range(num_passes):
            if p == num_passes - 1:
                # Last pass lands exactly on the requested output size
                # (avoids rounding drift from repeated doubling).
                target_w = block_state.upscaled_width
                target_h = block_state.upscaled_height
            else:
                target_w = int(current_w * per_pass_factor)
                target_h = int(current_h * per_pass_factor)

            pass_upscaled = current_image.resize((target_w, target_h), PIL.Image.LANCZOS)
            block_state._current_pass_strength = strength_per_pass[p]

            # In progressive mode the control image is the current pass's
            # upscaled input (tile-style conditioning), not the user image.
            ctrl_pil = None
            if use_controlnet:
                ctrl_pil = pass_upscaled.copy()

            logger.info(
                f"Progressive pass {p + 1}/{num_passes}: "
                f"{current_w}x{current_h} -> {target_w}x{target_h} "
                f"(strength={strength_per_pass[p]:.2f})"
            )

            decoded_np = self._run_single_pass(
                components, block_state,
                upscaled_image=pass_upscaled,
                h=target_h,
                w=target_w,
                ctrl_pil=ctrl_pil,
                use_controlnet=use_controlnet,
                latent_tile_size=latent_tile_size,
                latent_overlap=latent_overlap,
            )

            # Round-trip through uint8 PIL so the next pass resizes the
            # refined result, not raw floats.
            result_uint8 = (np.clip(decoded_np, 0, 1) * 255).astype(np.uint8)
            current_image = PIL.Image.fromarray(result_uint8)
            current_w, current_h = current_image.width, current_image.height

    # Convert the float [0, 1] HWC array from the final pass into the
    # requested output type (unknown types fall back to PIL).
    h = block_state.upscaled_height
    w = block_state.upscaled_width
    result_uint8 = (np.clip(decoded_np, 0, 1) * 255).astype(np.uint8)
    if output_type == "pil":
        block_state.images = [PIL.Image.fromarray(result_uint8)]
    elif output_type == "np":
        block_state.images = [decoded_np]
    elif output_type == "pt":
        block_state.images = [torch.from_numpy(decoded_np).permute(2, 0, 1).unsqueeze(0)]
    else:
        block_state.images = [PIL.Image.fromarray(result_uint8)]

    # Metadata is always attached; `return_metadata` only gates the log line.
    total_time = time.time() - t_start
    metadata = {
        "input_size": orig_input_size,
        "output_size": (w, h),
        "upscale_factor": upscale_factor,
        "num_passes": num_passes,
        "strength_per_pass": strength_per_pass,
        "total_time": total_time,
    }
    block_state.metadata = metadata

    if return_metadata:
        logger.info(
            f"MultiDiffusion complete: {orig_input_size} -> ({w}, {h}), "
            f"{num_passes} pass(es), {total_time:.1f}s"
        )

    self.set_block_state(state, block_state)
    return components, state
|
|
|
|
| |
| |
| |
|
|
| """Block composition for Modular SDXL Upscale.""" |
|
|
| from diffusers.utils import logging |
| from diffusers.modular_pipelines.modular_pipeline import SequentialPipelineBlocks |
| from diffusers.modular_pipelines.modular_pipeline_utils import OutputParam |
| from diffusers.modular_pipelines.stable_diffusion_xl.before_denoise import ( |
| StableDiffusionXLImg2ImgSetTimestepsStep, |
| StableDiffusionXLInputStep, |
| ) |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
class MultiDiffusionUpscaleBlocks(SequentialPipelineBlocks):
    """Sequential block composition for MultiDiffusion-tiled SDXL upscaling.

    Chains the standard SDXL preparation steps with a custom MultiDiffusion
    denoise step that blends noise predictions across overlapping latent
    tiles, giving seamless output at arbitrary resolutions.

    Execution order::

        [0] text_encoder - SDXL TextEncoderStep (reused)
        [1] upscale - Lanczos resize
        [2] input - SDXL InputStep (reused)
        [3] set_timesteps - SDXL Img2Img SetTimestepsStep (reused)
        [4] multidiffusion - MultiDiffusion step

    The final step owns VAE encode, tiled denoising with blending, and VAE
    decode, relying on VAE tiling to keep memory bounded.
    """

    # Step classes and their names, index-aligned.
    block_classes = [
        UltimateSDUpscaleTextEncoderStep,
        UltimateSDUpscaleUpscaleStep,
        StableDiffusionXLInputStep,
        StableDiffusionXLImg2ImgSetTimestepsStep,
        UltimateSDUpscaleMultiDiffusionStep,
    ]
    block_names = [
        "text_encoder",
        "upscale",
        "input",
        "set_timesteps",
        "multidiffusion",
    ]

    # Supported workflows and the inputs each one requires.
    _workflow_map = {
        "upscale": {"image": True, "prompt": True},
        "upscale_controlnet": {"image": True, "control_image": True, "prompt": True},
    }

    @property
    def description(self):
        # Assembled line-by-line; the joined result is the canonical
        # user-facing description string.
        parts = [
            "MultiDiffusion upscale pipeline for Stable Diffusion XL.",
            "Upscales an input image and refines it using tiled denoising with "
            "latent-space noise prediction blending. Produces seamless output at "
            "any resolution without tile-boundary artifacts.",
            "Supports optional ControlNet Tile conditioning for improved fidelity.",
        ]
        return "\n".join(parts)

    @property
    def outputs(self):
        images_param = OutputParam.template("images")
        return [images_param]
|
|
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Modular pipeline class for tiled SDXL upscaling. |
| |
| Reuses ``StableDiffusionXLModularPipeline`` since all components (VAE, UNet, |
| text encoders, scheduler) are the same. The only addition is the |
| ``default_blocks_name`` pointing to our custom block composition. |
| """ |
|
|
| from diffusers.modular_pipelines.stable_diffusion_xl.modular_pipeline import StableDiffusionXLModularPipeline |
|
|
|
|
class UltimateSDUpscaleModularPipeline(StableDiffusionXLModularPipeline):
    """A ModularPipeline for tiled SDXL upscaling.

    Inherits all SDXL component properties (``vae_scale_factor``,
    ``default_sample_size``, etc.) and overrides the default blocks to use
    the tiled upscaling block composition.

    > [!WARNING]
    > This is an experimental feature and is likely to change in the future.
    """

    # Must match the name of the blocks class defined in this file. The
    # composition above is `MultiDiffusionUpscaleBlocks`; the previous
    # "UltimateSDUpscaleBlocks" string pointed at a class that no longer
    # exists here, so block resolution by name would fail.
    default_blocks_name = "MultiDiffusionUpscaleBlocks"
|
|
|
|