# Source: Hugging Face Hub upload by akshan-main, revision 7183ccf (verified).
# The following two lines were residual web-page text from the Hub file
# viewer; kept here as comments so the module remains importable.
"""Modular SDXL Upscale - consolidated Hub block.
MultiDiffusion tiled upscaling for Stable Diffusion XL using Modular Diffusers.
"""
# ============================================================
# utils_tiling
# ============================================================
"""Tile planning and cosine blending weights for MultiDiffusion."""
from dataclasses import dataclass
import torch
@dataclass
class LatentTileSpec:
    """Tile specification in latent space.

    Coordinates and sizes are in latent pixels; elsewhere in this file one
    latent pixel corresponds to ``vae_scale_factor`` image pixels (8 here).

    Attributes:
        y: Top edge in latent pixels.
        x: Left edge in latent pixels.
        h: Height in latent pixels.
        w: Width in latent pixels.
    """

    y: int
    x: int
    h: int
    w: int
def validate_tile_params(tile_size: int, overlap: int) -> None:
    """Validate MultiDiffusion tiling parameters.

    Raises:
        ValueError: If ``tile_size`` is not positive, ``overlap`` is
            negative, or ``overlap`` is not strictly smaller than
            ``tile_size``.
    """
    problem = None
    if tile_size <= 0:
        problem = f"`tile_size` must be positive, got {tile_size}."
    elif overlap < 0:
        problem = f"`overlap` must be non-negative, got {overlap}."
    elif overlap >= tile_size:
        problem = (
            f"`overlap` must be less than `tile_size`. "
            f"Got overlap={overlap}, tile_size={tile_size}."
        )
    if problem is not None:
        raise ValueError(problem)
def plan_latent_tiles(
    latent_h: int,
    latent_w: int,
    tile_size: int = 64,
    overlap: int = 8,
) -> list[LatentTileSpec]:
    """Plan overlapping tiles in latent space for MultiDiffusion.

    Adjacent tiles overlap by ``overlap`` latent pixels. The last tile on
    each axis is shifted back so it ends flush with the latent boundary.
    """
    validate_tile_params(tile_size, overlap)
    stride = tile_size - overlap

    def axis_starts(extent: int) -> list[int]:
        # Start offsets along one axis. Degenerate axes produce no tiles;
        # an axis no longer than one tile produces a single tile at 0; the
        # final start is clamped so the tile finishes exactly at `extent`.
        if extent <= 0:
            return []
        if extent <= tile_size:
            return [0]
        starts = []
        pos = 0
        while pos + tile_size < extent:
            starts.append(pos)
            pos += stride
        starts.append(min(pos, extent - tile_size))
        return starts

    tile_h = min(tile_size, latent_h)
    tile_w = min(tile_size, latent_w)
    return [
        LatentTileSpec(y=top, x=left, h=tile_h, w=tile_w)
        for top in axis_starts(latent_h)
        for left in axis_starts(latent_w)
    ]
def make_cosine_tile_weight(
    h: int,
    w: int,
    overlap: int,
    device: torch.device,
    dtype: torch.dtype,
    is_top: bool = False,
    is_bottom: bool = False,
    is_left: bool = False,
    is_right: bool = False,
) -> torch.Tensor:
    """Boundary-aware cosine blending weight for one tile.

    Returns a tensor of shape ``(1, 1, h, w)``. Sides flush with the canvas
    edge keep weight 1.0 (no fade); sides bordering another tile get a
    half-cosine ramp rising from 0 over the overlap region.
    """
    import math

    ramp = min(overlap, h // 2, w // 2)
    if ramp <= 0:
        # Zero overlap (or tile too small for any fade): uniform weight.
        return torch.ones(1, 1, h, w, device=device, dtype=dtype)

    rise = torch.tensor(
        [0.5 * (1 - math.cos(math.pi * step / ramp)) for step in range(ramp)],
        device=device,
        dtype=dtype,
    )
    fall = rise.flip(0)

    row_profile = torch.ones(h, device=device, dtype=dtype)
    col_profile = torch.ones(w, device=device, dtype=dtype)
    if not is_top:
        row_profile[:ramp] = rise
    if not is_bottom:
        row_profile[-ramp:] = fall
    if not is_left:
        col_profile[:ramp] = rise
    if not is_right:
        col_profile[-ramp:] = fall

    # Outer product -> (h, w), then prepend batch and channel dims.
    return (row_profile[:, None] * col_profile[None, :])[None, None]
# ============================================================
# input
# ============================================================
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Input steps for Modular SDXL Upscale: text encoding, Lanczos upscale."""
import PIL.Image
import torch
from diffusers.utils import logging
from diffusers.modular_pipelines.modular_pipeline import ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam
from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLTextEncoderStep
logger = logging.get_logger(__name__)
class UltimateSDUpscaleTextEncoderStep(StableDiffusionXLTextEncoderStep):
    """SDXL text encoder step that applies guidance scale before encoding.

    Before delegating to the stock SDXL text encoder, this step pushes the
    requested ``guidance_scale`` onto the guider (so unconditional
    embeddings are produced whenever CFG is active) and, unless disabled,
    substitutes a default negative prompt tuned for upscaling when the
    caller supplies none.
    """

    DEFAULT_NEGATIVE_PROMPT = "blurry, low quality, artifacts, noise, jpeg compression"

    @property
    def inputs(self) -> list[InputParam]:
        extra = [
            InputParam(
                "guidance_scale",
                type_hint=float,
                default=7.5,
                description="Classifier-Free Guidance scale.",
            ),
            InputParam(
                "use_default_negative",
                type_hint=bool,
                default=True,
                description="Apply default negative prompt when none is provided.",
            ),
        ]
        return super().inputs + extra

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        # Sync the guider's scale before encoding so the CFG batch layout
        # (conditional + unconditional) is decided correctly.
        guider = getattr(components, "guider", None)
        if guider is not None:
            guider.guidance_scale = getattr(block_state, "guidance_scale", 7.5)

        # Fill in the upscaling default when no negative prompt was given.
        if getattr(block_state, "use_default_negative", True):
            current = getattr(block_state, "negative_prompt", None)
            if current is None or current == "":
                block_state.negative_prompt = self.DEFAULT_NEGATIVE_PROMPT
                state.set("negative_prompt", self.DEFAULT_NEGATIVE_PROMPT)

        return super().__call__(components, state)
class UltimateSDUpscaleUpscaleStep(ModularPipelineBlocks):
    """Upscales the input image using Lanczos interpolation."""

    @property
    def description(self) -> str:
        return "Upscale input image using Lanczos interpolation."

    @property
    def inputs(self) -> list[InputParam]:
        image_param = InputParam(
            "image",
            type_hint=PIL.Image.Image,
            required=True,
            description="Input image to upscale.",
        )
        factor_param = InputParam(
            "upscale_factor",
            type_hint=float,
            default=2.0,
            description="Scale multiplier.",
        )
        return [image_param, factor_param]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("upscaled_image", type_hint=PIL.Image.Image),
            OutputParam("upscaled_width", type_hint=int),
            OutputParam("upscaled_height", type_hint=int),
        ]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        source = block_state.image
        factor = block_state.upscale_factor
        if not isinstance(source, PIL.Image.Image):
            raise ValueError(f"Expected PIL.Image, got {type(source)}.")

        # Target dimensions are truncated to PIL's integer pixel grid.
        target_w = int(source.width * factor)
        target_h = int(source.height * factor)

        block_state.upscaled_image = source.resize((target_w, target_h), PIL.Image.LANCZOS)
        block_state.upscaled_width = target_w
        block_state.upscaled_height = target_h
        logger.info(f"Upscaled {source.width}x{source.height} -> {target_w}x{target_h}")

        self.set_block_state(state, block_state)
        return components, state
# ============================================================
# denoise
# ============================================================
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MultiDiffusion tiled upscaling step for Modular SDXL Upscale.
Blends noise predictions from overlapping latent tiles using cosine weights.
Reuses SDXL blocks via their public interface.
"""
import math
import time
import numpy as np
import PIL.Image
import torch
from tqdm.auto import tqdm
from diffusers.configuration_utils import FrozenDict
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from diffusers.schedulers import DPMSolverMultistepScheduler, EulerDiscreteScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.modular_pipelines.modular_pipeline import (
ModularPipelineBlocks,
PipelineState,
)
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from diffusers.modular_pipelines.stable_diffusion_xl.before_denoise import (
StableDiffusionXLControlNetInputStep,
StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
StableDiffusionXLImg2ImgPrepareLatentsStep,
prepare_latents_img2img,
)
from diffusers.modular_pipelines.stable_diffusion_xl.decoders import StableDiffusionXLDecodeStep
from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLVaeEncoderStep
logger = logging.get_logger(__name__)
# ---------------------------------------------------------------------------
# Helper: populate a PipelineState from a dict
# ---------------------------------------------------------------------------
def _make_state(values: dict, kwargs_type_map: dict | None = None) -> PipelineState:
    """Build a fresh ``PipelineState`` pre-populated from ``values``.

    ``kwargs_type_map`` optionally maps a key to the ``kwargs_type`` label
    passed to ``PipelineState.set`` (``None`` when absent).
    """
    type_map = kwargs_type_map if kwargs_type_map is not None else {}
    state = PipelineState()
    for key in values:
        state.set(key, values[key], type_map.get(key))
    return state
def _to_pil_rgb_image(image) -> PIL.Image.Image:
    """Convert a tensor/ndarray/PIL image to a RGB PIL image."""
    if isinstance(image, PIL.Image.Image):
        return image.convert("RGB")

    if torch.is_tensor(image):
        # Squeeze a size-1 batch dim, move channels last if needed, and hand
        # off to the ndarray path below.
        data = image.detach().cpu()
        if data.ndim == 4:
            if data.shape[0] != 1:
                raise ValueError(
                    f"`control_image` tensor batch must be 1 for tiled upscaling, got shape {tuple(data.shape)}."
                )
            data = data[0]
        if data.ndim == 3 and data.shape[0] in (1, 3, 4) and data.shape[-1] not in (1, 3, 4):
            data = data.permute(1, 2, 0)
        image = data.numpy()

    if not isinstance(image, np.ndarray):
        raise ValueError(
            f"Unsupported `control_image` type {type(image)}. Expected PIL.Image, torch.Tensor, or numpy.ndarray."
        )

    array = image
    if array.ndim == 4:
        if array.shape[0] != 1:
            raise ValueError(
                f"`control_image` ndarray batch must be 1 for tiled upscaling, got shape {array.shape}."
            )
        array = array[0]
    if array.ndim == 3 and array.shape[0] in (1, 3, 4) and array.shape[-1] not in (1, 3, 4):
        array = np.transpose(array, (1, 2, 0))
    if array.ndim == 2:
        # Grayscale: replicate to three channels.
        array = np.stack([array] * 3, axis=-1)
    if array.ndim != 3:
        raise ValueError(f"`control_image` must have 2 or 3 dimensions, got shape {array.shape}.")
    if array.shape[-1] == 1:
        array = np.repeat(array, 3, axis=-1)
    if array.shape[-1] == 4:
        # Drop the alpha channel.
        array = array[..., :3]
    if array.shape[-1] != 3:
        raise ValueError(f"`control_image` channel dimension must be 1/3/4, got shape {array.shape}.")

    if array.dtype != np.uint8:
        array = np.asarray(array, dtype=np.float32)
        peak = float(np.max(array)) if array.size > 0 else 1.0
        # Heuristic: values no greater than 1 are treated as normalized
        # [0, 1] floats; anything else as a [0, 255] range.
        if peak <= 1.0:
            array = (np.clip(array, 0.0, 1.0) * 255.0).astype(np.uint8)
        else:
            array = np.clip(array, 0.0, 255.0).astype(np.uint8)

    return PIL.Image.fromarray(array).convert("RGB")
# ---------------------------------------------------------------------------
# Scheduler swap helper (Feature 5)
# ---------------------------------------------------------------------------
# Lowercased human-readable scheduler names -> canonical scheduler class
# names, consumed by `_swap_scheduler`. The "+karras" suffix marks
# DPMSolverMultistep with Karras sigmas enabled.
_SCHEDULER_ALIASES = {
    "euler": "EulerDiscreteScheduler",
    "euler discrete": "EulerDiscreteScheduler",
    "eulerdiscretescheduler": "EulerDiscreteScheduler",
    "dpm++ 2m": "DPMSolverMultistepScheduler",
    "dpmsolvermultistepscheduler": "DPMSolverMultistepScheduler",
    "dpm++ 2m karras": "DPMSolverMultistepScheduler+karras",
}
def _swap_scheduler(components, scheduler_name: str):
    """Swap the scheduler on ``components`` given a human-readable name.

    Supported names (case-insensitive):
    - ``"Euler"`` / ``"EulerDiscreteScheduler"``
    - ``"DPM++ 2M"`` / ``"DPMSolverMultistepScheduler"``
    - ``"DPM++ 2M Karras"`` (DPMSolverMultistep with Karras sigmas)
    If the requested scheduler (including its Karras-sigma setting) is
    already active, this is a no-op. Unknown names keep the current
    scheduler and log a warning.
    """
    key = scheduler_name.strip().lower()
    resolved = _SCHEDULER_ALIASES.get(key, key)
    use_karras = resolved.endswith("+karras")
    if use_karras:
        resolved = resolved.removesuffix("+karras")
    current = type(components.scheduler).__name__
    if resolved == "EulerDiscreteScheduler":
        if current != "EulerDiscreteScheduler":
            components.scheduler = EulerDiscreteScheduler.from_config(components.scheduler.config)
            logger.info("Swapped scheduler to EulerDiscreteScheduler")
    elif resolved == "DPMSolverMultistepScheduler":
        current_karras = bool(getattr(components.scheduler.config, "use_karras_sigmas", False))
        # Rebuild when the class differs OR the Karras-sigma flag does not
        # match the request. (Bug fix: the original only rebuilt when Karras
        # was requested but missing, so asking for plain "DPM++ 2M" while
        # Karras sigmas were active was a silent no-op.)
        if current != "DPMSolverMultistepScheduler" or current_karras != use_karras:
            components.scheduler = DPMSolverMultistepScheduler.from_config(
                components.scheduler.config, use_karras_sigmas=use_karras
            )
            logger.info(f"Swapped scheduler to DPMSolverMultistepScheduler (karras={use_karras})")
    else:
        logger.warning(
            f"Unknown scheduler_name '{scheduler_name}'. Keeping current scheduler "
            f"({current}). Supported: 'Euler', 'DPM++ 2M', 'DPM++ 2M Karras'."
        )
# ---------------------------------------------------------------------------
# Auto-strength helper (Feature 2)
# ---------------------------------------------------------------------------
def _compute_auto_strength(upscale_factor: float, pass_index: int, num_passes: int) -> float:
"""Return the auto-scaled denoise strength for a given pass.
Rules:
- Single-pass 2x: 0.3
- Single-pass 4x: 0.15
- Progressive passes: first pass=0.3, subsequent passes=0.2
"""
if num_passes > 1:
return 0.3 if pass_index == 0 else 0.2
# Single pass
if upscale_factor <= 2.0:
return 0.3
elif upscale_factor <= 4.0:
return 0.15
else:
return 0.1
class UltimateSDUpscaleMultiDiffusionStep(ModularPipelineBlocks):
    """Single block that encodes, denoises with MultiDiffusion, and decodes.

    MultiDiffusion inverts the standard tile loop: the **outer** loop iterates
    over timesteps and the **inner** loop iterates over overlapping latent
    tiles. At each timestep, per-tile noise predictions are blended with
    cosine-ramp overlap weights, then a single scheduler step is applied to
    the full latent tensor. This eliminates tile-boundary artifacts because
    blending happens in noise-prediction space, not pixel space.

    The full flow (per pass):
    1. Enable VAE tiling for memory-efficient encode/decode.
    2. VAE-encode the upscaled image to full-resolution latents.
    3. Add noise at the strength-determined level and plan the overlapping
       latent tiles once.
    4. For each timestep:
       a. For each tile: crop latents, run UNet (+ optional ControlNet)
          through the guider for CFG, accumulate weighted noise predictions.
       b. Normalize predictions by accumulated weights.
       c. One ``scheduler.step`` on the full blended prediction.
    5. VAE-decode the final latents.
    """

    # NOTE(review): model family tag; presumably read by modular-pipeline
    # tooling — confirm consumers before renaming.
    model_name = "stable-diffusion-xl"
    def __init__(self):
        """Instantiate the stock SDXL sub-blocks this step delegates to."""
        super().__init__()
        # Each sub-block is reused via its public __call__ interface inside
        # _run_single_pass (encode -> prepare -> ... -> decode).
        self._vae_encoder = StableDiffusionXLVaeEncoderStep()
        self._prepare_latents = StableDiffusionXLImg2ImgPrepareLatentsStep()
        self._prepare_add_cond = StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep()
        self._prepare_controlnet = StableDiffusionXLControlNetInputStep()
        self._decode = StableDiffusionXLDecodeStep()
@property
def description(self) -> str:
return (
"MultiDiffusion tiled denoising: encodes the full upscaled image, "
"denoises with latent-space noise-prediction blending across "
"overlapping tiles, then decodes. Produces seamless output at any "
"resolution without tile-boundary artifacts."
)
@property
def expected_components(self) -> list[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config",
),
ComponentSpec(
"control_image_processor",
VaeImageProcessor,
config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}),
default_creation_method="from_config",
),
ComponentSpec("scheduler", EulerDiscreteScheduler),
ComponentSpec("unet", UNet2DConditionModel),
ComponentSpec("controlnet", ControlNetModel),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config",
),
]
@property
def expected_configs(self) -> list[ConfigSpec]:
return [ConfigSpec("requires_aesthetics_score", False)]
    @property
    def inputs(self) -> list[InputParam]:
        """Declarative list of every state value this block reads.

        Grouped as: upstream image/latent values, embeddings from the text
        encoder, MultiDiffusion tiling knobs, optional ControlNet inputs,
        and the feature flags (progressive passes, auto-strength, output
        metadata, scheduler swap).
        """
        return [
            InputParam("upscaled_image", type_hint=PIL.Image.Image, required=True),
            InputParam("upscaled_height", type_hint=int, required=True),
            InputParam("upscaled_width", type_hint=int, required=True),
            InputParam("image", type_hint=PIL.Image.Image,
                       description="Original input image (before upscaling). Needed for progressive mode."),
            InputParam("upscale_factor", type_hint=float, default=2.0,
                       description="Total upscale factor. Used for auto-strength and progressive upscaling."),
            InputParam("generator"),
            InputParam("batch_size", type_hint=int, required=True),
            InputParam("num_images_per_prompt", type_hint=int, default=1),
            InputParam("dtype", type_hint=torch.dtype, required=True),
            InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, kwargs_type="denoiser_input_fields"),
            InputParam("num_inference_steps", type_hint=int, default=50),
            InputParam("strength", type_hint=float, default=0.3),
            InputParam("timesteps", type_hint=torch.Tensor, required=True),
            InputParam("latent_timestep", type_hint=torch.Tensor, required=True),
            InputParam("denoising_start"),
            InputParam("denoising_end"),
            InputParam("output_type", type_hint=str, default="pil"),
            # Prompt embeddings for guider (kwargs_type must match text encoder outputs)
            InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, kwargs_type="denoiser_input_fields"),
            InputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
            InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
            InputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
            InputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
            InputParam("eta", type_hint=float, default=0.0),
            # Guidance scale for CFG
            InputParam("guidance_scale", type_hint=float, default=7.5,
                       description="Classifier-Free Guidance scale. Higher values produce images more aligned "
                                   "with the prompt at the expense of lower image quality."),
            # MultiDiffusion params
            InputParam("latent_tile_size", type_hint=int, default=64,
                       description="Tile size in latent pixels (64 = 512px). For single pass, set >= latent dims."),
            InputParam("latent_overlap", type_hint=int, default=16,
                       description="Overlap in latent pixels (16 = 128px)."),
            # ControlNet params
            InputParam("control_image",
                       description="Optional ControlNet conditioning image."),
            InputParam("control_guidance_start", default=0.0),
            InputParam("control_guidance_end", default=1.0),
            InputParam("controlnet_conditioning_scale", default=1.0),
            InputParam("guess_mode", default=False),
            # Progressive upscaling (Feature 1)
            InputParam("progressive", type_hint=bool, default=True,
                       description="When True and upscale_factor > 2, split into multiple 2x passes "
                                   "instead of one big jump. E.g. 4x = 2x then 2x."),
            # Auto-strength (Feature 2)
            InputParam("auto_strength", type_hint=bool, default=True,
                       description="When True and user does not explicitly pass strength, automatically "
                                   "scale denoise strength based on upscale factor and pass index."),
            # Output metadata (Feature 4)
            InputParam("return_metadata", type_hint=bool, default=False,
                       description="When True, include generation metadata (sizes, passes, timings) "
                                   "in the output."),
            # Scheduler selection (Feature 5)
            InputParam("scheduler_name", type_hint=str, default=None,
                       description="Optional scheduler name to swap before running. "
                                   "Supported: 'Euler', 'DPM++ 2M', 'DPM++ 2M Karras'."),
        ]
@property
def intermediate_outputs(self) -> list[OutputParam]:
return [
OutputParam("images", type_hint=list, description="Final upscaled output images."),
OutputParam("metadata", type_hint=dict,
description="Generation metadata (when return_metadata=True)."),
]
    def _run_tile_unet(
        self,
        components,
        tile_latents: torch.Tensor,
        t: int,
        i: int,
        block_state,
        controlnet_cond_tile=None,
    ) -> torch.Tensor:
        """Run guider + UNet (+ optional ControlNet) on one tile, return noise_pred.

        Args:
            components: Pipeline components (scheduler, guider, unet, controlnet).
            tile_latents: Latent crop for this tile.
            t: Current scheduler timestep value.
            i: Zero-based index of the current denoising step.
            block_state: Block state carrying prompt embeddings / time ids and
                the ControlNet bookkeeping attributes (``_cn_*``) populated by
                ``_run_single_pass``.
            controlnet_cond_tile: Matching crop of the ControlNet conditioning
                tensor, or None to skip the ControlNet branch.

        Returns:
            The guider-combined (e.g. CFG) noise prediction for this tile.
        """
        # Scale input as required by the active scheduler.
        scaled_latents = components.scheduler.scale_model_input(tile_latents, t)
        # Guider inputs — ensure negative embeddings are never None so the
        # unconditional CFG batch gets valid tensors for UNet + ControlNet.
        pos_prompt = getattr(block_state, "prompt_embeds", None)
        neg_prompt = getattr(block_state, "negative_prompt_embeds", None)
        pos_pooled = getattr(block_state, "pooled_prompt_embeds", None)
        neg_pooled = getattr(block_state, "negative_pooled_prompt_embeds", None)
        pos_time_ids = getattr(block_state, "add_time_ids", None)
        neg_time_ids = getattr(block_state, "negative_add_time_ids", None)
        # Zero embeddings stand in for missing negatives; time ids are reused
        # as-is (they encode sizes/crops, not text).
        if neg_prompt is None and pos_prompt is not None:
            neg_prompt = torch.zeros_like(pos_prompt)
        if neg_pooled is None and pos_pooled is not None:
            neg_pooled = torch.zeros_like(pos_pooled)
        if neg_time_ids is None and pos_time_ids is not None:
            neg_time_ids = pos_time_ids.clone()
        guider_inputs = {
            "prompt_embeds": (pos_prompt, neg_prompt),
            "time_ids": (pos_time_ids, neg_time_ids),
            "text_embeds": (pos_pooled, neg_pooled),
        }
        components.guider.set_state(
            step=i,
            num_inference_steps=block_state.num_inference_steps,
            timestep=t,
        )
        guider_state = components.guider.prepare_inputs(guider_inputs)
        # One batch per guidance branch (conditional / unconditional).
        for guider_state_batch in guider_state:
            components.guider.prepare_models(components.unet)
            added_cond_kwargs = {
                "text_embeds": guider_state_batch.text_embeds,
                "time_ids": guider_state_batch.time_ids,
            }
            down_block_res_samples = None
            mid_block_res_sample = None
            # ControlNet forward pass (skip for unconditional batch where text_embeds is None)
            if (
                controlnet_cond_tile is not None
                and components.controlnet is not None
                and guider_state_batch.text_embeds is not None
            ):
                cn_added_cond = {
                    "text_embeds": guider_state_batch.text_embeds,
                    "time_ids": guider_state_batch.time_ids,
                }
                cond_scale = block_state._cn_cond_scale
                # controlnet_keep down-weights conditioning per step
                # (guidance start/end window computed in _run_single_pass).
                if isinstance(block_state._cn_controlnet_keep, list) and i < len(block_state._cn_controlnet_keep):
                    keep_val = block_state._cn_controlnet_keep[i]
                else:
                    keep_val = 1.0
                if isinstance(cond_scale, list):
                    cond_scale = [c * keep_val for c in cond_scale]
                else:
                    cond_scale = cond_scale * keep_val
                guess_mode = getattr(block_state, "guess_mode", False)
                if guess_mode and not components.guider.is_conditional:
                    # Guess mode: the unconditional branch gets zero residuals.
                    # Shapes come from templates cached on a previous
                    # conditional pass; None (residuals disabled) before then.
                    down_block_res_samples = [torch.zeros_like(s) for s in block_state._cn_zeros_down] if hasattr(block_state, "_cn_zeros_down") else None
                    mid_block_res_sample = torch.zeros_like(block_state._cn_zeros_mid) if hasattr(block_state, "_cn_zeros_mid") else None
                else:
                    down_block_res_samples, mid_block_res_sample = components.controlnet(
                        scaled_latents,
                        t,
                        encoder_hidden_states=guider_state_batch.prompt_embeds,
                        controlnet_cond=controlnet_cond_tile,
                        conditioning_scale=cond_scale,
                        guess_mode=guess_mode,
                        added_cond_kwargs=cn_added_cond,
                        return_dict=False,
                    )
                    # Cache zero templates so later guess-mode unconditional
                    # batches can fabricate matching-shape residuals.
                    if not hasattr(block_state, "_cn_zeros_down"):
                        block_state._cn_zeros_down = [torch.zeros_like(d) for d in down_block_res_samples]
                        block_state._cn_zeros_mid = torch.zeros_like(mid_block_res_sample)
            unet_kwargs = {
                "sample": scaled_latents,
                "timestep": t,
                "encoder_hidden_states": guider_state_batch.prompt_embeds,
                "added_cond_kwargs": added_cond_kwargs,
                "return_dict": False,
            }
            if down_block_res_samples is not None:
                unet_kwargs["down_block_additional_residuals"] = down_block_res_samples
                unet_kwargs["mid_block_additional_residual"] = mid_block_res_sample
            guider_state_batch.noise_pred = components.unet(**unet_kwargs)[0]
            components.guider.cleanup_models(components.unet)
        # Combine the per-branch predictions into one noise estimate.
        noise_pred = components.guider(guider_state)[0]
        return noise_pred
def _run_single_pass(
self,
components,
block_state,
upscaled_image: PIL.Image.Image,
h: int,
w: int,
ctrl_pil,
use_controlnet: bool,
latent_tile_size: int,
latent_overlap: int,
) -> np.ndarray:
"""Run one MultiDiffusion encode-denoise-decode pass, return decoded numpy (h, w, 3)."""
from diffusers.modular_pipelines.stable_diffusion_xl.before_denoise import retrieve_timesteps
# --- Enable VAE tiling ---
if hasattr(components.vae, "enable_tiling"):
components.vae.enable_tiling()
# --- ControlNet setup for this pass ---
full_controlnet_cond = None
if use_controlnet and ctrl_pil is not None:
if ctrl_pil.size != (w, h):
ctrl_pil = ctrl_pil.resize((w, h), PIL.Image.LANCZOS)
# --- VAE encode ---
enc_state = _make_state({
"image": upscaled_image,
"height": h,
"width": w,
"generator": block_state.generator,
"dtype": block_state.dtype,
"preprocess_kwargs": None,
})
components, enc_state = self._vae_encoder(components, enc_state)
image_latents = enc_state.get("image_latents")
# --- Re-compute timesteps for this pass's strength ---
# The outer set_timesteps block stores num_inference_steps = int(original * strength)
# in the state (the number of denoising steps after strength truncation).
# To recover the original step count, we use: original = round(truncated / outer_strength).
# We use the user-provided strength (block_state.strength) as the outer strength,
# and _current_pass_strength as this pass's actual denoising strength.
pass_strength = block_state._current_pass_strength
truncated_steps = block_state.num_inference_steps
outer_strength = block_state.strength
if outer_strength > 0:
original_steps = max(1, round(truncated_steps / outer_strength))
else:
original_steps = max(1, truncated_steps)
_ts, _nsteps = retrieve_timesteps(
components.scheduler,
original_steps,
components._execution_device,
)
from diffusers.modular_pipelines.stable_diffusion_xl.before_denoise import StableDiffusionXLImg2ImgSetTimestepsStep
timesteps, num_inf_steps = StableDiffusionXLImg2ImgSetTimestepsStep.get_timesteps(
components, original_steps, pass_strength, components._execution_device,
)
latent_timestep = timesteps[:1].repeat(block_state.batch_size * block_state.num_images_per_prompt)
# --- Prepare latents (add noise) ---
lat_state = _make_state({
"image_latents": image_latents,
"latent_timestep": latent_timestep,
"batch_size": block_state.batch_size,
"num_images_per_prompt": block_state.num_images_per_prompt,
"dtype": block_state.dtype,
"generator": block_state.generator,
"latents": None,
"denoising_start": getattr(block_state, "denoising_start", None),
})
components, lat_state = self._prepare_latents(components, lat_state)
latents = lat_state.get("latents")
# ControlNet conditioning
if use_controlnet and ctrl_pil is not None:
ctrl_state = _make_state({
"control_image": ctrl_pil,
"control_guidance_start": getattr(block_state, "control_guidance_start", 0.0),
"control_guidance_end": getattr(block_state, "control_guidance_end", 1.0),
"controlnet_conditioning_scale": getattr(block_state, "controlnet_conditioning_scale", 1.0),
"guess_mode": getattr(block_state, "guess_mode", False),
"num_images_per_prompt": block_state.num_images_per_prompt,
"latents": latents,
"batch_size": block_state.batch_size,
"timesteps": timesteps,
"crops_coords": None,
})
components, ctrl_state = self._prepare_controlnet(components, ctrl_state)
full_controlnet_cond = ctrl_state.get("controlnet_cond")
block_state._cn_cond_scale = ctrl_state.get("conditioning_scale")
block_state._cn_controlnet_keep = ctrl_state.get("controlnet_keep")
block_state.guess_mode = ctrl_state.get("guess_mode")
# --- Additional conditioning ---
cond_state = _make_state({
"original_size": (h, w),
"target_size": (h, w),
"crops_coords_top_left": (0, 0),
"negative_original_size": None,
"negative_target_size": None,
"negative_crops_coords_top_left": (0, 0),
"num_images_per_prompt": block_state.num_images_per_prompt,
"aesthetic_score": 6.0,
"negative_aesthetic_score": 2.0,
"latents": latents,
"pooled_prompt_embeds": block_state.pooled_prompt_embeds,
"batch_size": block_state.batch_size,
})
components, cond_state = self._prepare_add_cond(components, cond_state)
block_state.add_time_ids = cond_state.get("add_time_ids")
block_state.negative_add_time_ids = cond_state.get("negative_add_time_ids")
# --- Plan latent tiles ---
latent_h, latent_w = latents.shape[-2], latents.shape[-1]
tile_specs = plan_latent_tiles(latent_h, latent_w, latent_tile_size, latent_overlap)
num_tiles = len(tile_specs)
logger.info(
f"MultiDiffusion: {num_tiles} latent tiles "
f"({latent_h}x{latent_w}, tile={latent_tile_size}, overlap={latent_overlap})"
)
# --- Guider setup ---
guidance_scale = getattr(block_state, "guidance_scale", 7.5)
components.guider.guidance_scale = guidance_scale
disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False
if disable_guidance:
components.guider.disable()
else:
components.guider.enable()
# Update block_state with this pass's timestep info
block_state.num_inference_steps = num_inf_steps
# --- MultiDiffusion denoise loop ---
vae_scale_factor = int(getattr(components, "vae_scale_factor", 8))
progress_kwargs = getattr(components, "_progress_bar_config", {})
if not isinstance(progress_kwargs, dict):
progress_kwargs = {}
for i, t in enumerate(
tqdm(timesteps, total=num_inf_steps, desc="MultiDiffusion", **progress_kwargs)
):
noise_pred_accum = torch.zeros_like(latents, dtype=torch.float32)
weight_accum = torch.zeros(
1, 1, latent_h, latent_w,
device=latents.device, dtype=torch.float32,
)
for tile in tile_specs:
tile_latents = latents[:, :, tile.y:tile.y + tile.h, tile.x:tile.x + tile.w].clone()
cn_tile = None
if use_controlnet and full_controlnet_cond is not None:
py = tile.y * vae_scale_factor
px = tile.x * vae_scale_factor
ph = tile.h * vae_scale_factor
pw = tile.w * vae_scale_factor
cn_tile = full_controlnet_cond[:, :, py:py + ph, px:px + pw]
tile_noise_pred = self._run_tile_unet(
components, tile_latents, t, i, block_state, cn_tile,
)
tile_weight = _make_cosine_tile_weight(
tile.h, tile.w, latent_overlap,
latents.device, torch.float32,
is_top=(tile.y == 0),
is_bottom=(tile.y + tile.h >= latent_h),
is_left=(tile.x == 0),
is_right=(tile.x + tile.w >= latent_w),
)
noise_pred_accum[:, :, tile.y:tile.y + tile.h, tile.x:tile.x + tile.w] += (
tile_noise_pred.to(torch.float32) * tile_weight
)
weight_accum[:, :, tile.y:tile.y + tile.h, tile.x:tile.x + tile.w] += tile_weight
blended_noise_pred = noise_pred_accum / weight_accum.clamp(min=1e-6)
blended_noise_pred = torch.nan_to_num(blended_noise_pred, nan=0.0, posinf=0.0, neginf=0.0)
blended_noise_pred = blended_noise_pred.to(latents.dtype)
latents_dtype = latents.dtype
latents = components.scheduler.step(
blended_noise_pred, t, latents,
return_dict=False,
)[0]
if latents.dtype != latents_dtype and torch.backends.mps.is_available():
latents = latents.to(latents_dtype)
# --- Decode ---
decode_state = _make_state({
"latents": latents,
"output_type": "np",
})
components, decode_state = self._decode(components, decode_state)
decoded_images = decode_state.get("images")
decoded_np = decoded_images[0]
if decoded_np.shape[0] != h or decoded_np.shape[1] != w:
pil_out = PIL.Image.fromarray((np.clip(decoded_np, 0, 1) * 255).astype(np.uint8))
pil_out = pil_out.resize((w, h), PIL.Image.LANCZOS)
decoded_np = np.array(pil_out).astype(np.float32) / 255.0
return decoded_np
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
    """Run the tiled MultiDiffusion upscale on the prepared block state.

    Orchestrates the whole step: an optional scheduler swap, a single-pass
    or progressive multi-pass tiled denoise via ``self._run_single_pass``,
    optional ControlNet conditioning, output formatting (``pil``/``np``/
    ``pt``), and run metadata attached to ``block_state.metadata``.

    NOTE(review): despite the ``-> PipelineState`` annotation this returns
    ``(components, state)``, matching the modular-blocks call convention
    used elsewhere in this file.
    """
    block_state = self.get_block_state(state)
    t_start = time.time()  # wall-clock start, reported as metadata["total_time"]
    output_type = block_state.output_type
    latent_tile_size = block_state.latent_tile_size
    latent_overlap = block_state.latent_overlap
    # --- Feature 5: Scheduler swap ---
    # Optionally replace `components.scheduler` by name before denoising.
    scheduler_name = getattr(block_state, "scheduler_name", None)
    if scheduler_name is not None:
        _swap_scheduler(components, scheduler_name)
    # --- Feature 1 & 2: Progressive upscaling + auto-strength ---
    upscale_factor = getattr(block_state, "upscale_factor", 2.0)
    progressive = getattr(block_state, "progressive", True)
    auto_strength = getattr(block_state, "auto_strength", True)
    return_metadata = getattr(block_state, "return_metadata", False)
    # Determine if user explicitly set strength (not using default)
    # The default in InputParam is 0.3; if auto_strength is True we override it.
    user_strength = block_state.strength
    # We treat strength=0.3 (the InputParam default) as "not explicitly set" when
    # auto_strength is enabled. Users who truly want 0.3 can set auto_strength=False.
    # Number of progressive passes: ceil(log2(factor)) roughly-2x passes
    # (e.g. factor 4 -> 2 passes); factors <= 2 always run a single pass.
    if progressive and upscale_factor > 2.0:
        num_passes = max(1, int(math.ceil(math.log2(upscale_factor))))
    else:
        num_passes = 1
    # Compute strength per pass
    strength_per_pass = []
    for p in range(num_passes):
        if auto_strength:
            strength_per_pass.append(
                _compute_auto_strength(upscale_factor, p, num_passes)
            )
        else:
            strength_per_pass.append(user_strength)
    # Original input image for progressive mode
    original_image = getattr(block_state, "image", None)
    input_w = block_state.upscaled_width
    input_h = block_state.upscaled_height
    # For tracking the original pre-upscale size
    if original_image is not None:
        orig_input_size = (original_image.width, original_image.height)
    else:
        # Infer from upscale_factor
        orig_input_size = (
            int(round(input_w / upscale_factor)),
            int(round(input_h / upscale_factor)),
        )
    # --- ControlNet setup ---
    control_image_raw = getattr(block_state, "control_image", None)
    use_controlnet = False
    if control_image_raw is not None:
        if isinstance(control_image_raw, list):
            raise ValueError(
                "MultiDiffusion currently supports a single `control_image`, not a list."
            )
        if not hasattr(components, "controlnet") or components.controlnet is None:
            raise ValueError(
                "`control_image` was provided but `controlnet` component is missing. "
                "Load a ControlNet model into `pipe.controlnet` first."
            )
        use_controlnet = True
        logger.info("MultiDiffusion: ControlNet enabled.")
    if num_passes == 1:
        # --- Single pass (original behavior) ---
        block_state._current_pass_strength = strength_per_pass[0]
        ctrl_pil = None
        if use_controlnet:
            ctrl_pil = _to_pil_rgb_image(control_image_raw)
            h, w = block_state.upscaled_height, block_state.upscaled_width
            # The control image must match the denoise resolution exactly.
            if ctrl_pil.size != (w, h):
                ctrl_pil = ctrl_pil.resize((w, h), PIL.Image.LANCZOS)
        decoded_np = self._run_single_pass(
            components, block_state,
            upscaled_image=block_state.upscaled_image,
            h=block_state.upscaled_height,
            w=block_state.upscaled_width,
            ctrl_pil=ctrl_pil,
            use_controlnet=use_controlnet,
            latent_tile_size=latent_tile_size,
            latent_overlap=latent_overlap,
        )
    else:
        # --- Progressive multi-pass ---
        # Start from the original (pre-upscale) image
        if original_image is None:
            # Fall back to downscaling the upscaled_image back
            original_image = block_state.upscaled_image.resize(
                orig_input_size, PIL.Image.LANCZOS,
            )
        current_image = original_image
        per_pass_factor = 2.0  # each intermediate pass doubles the resolution
        current_w, current_h = current_image.width, current_image.height
        for p in range(num_passes):
            # Compute target size for this pass
            if p == num_passes - 1:
                # Last pass: go to exact target size
                target_w = block_state.upscaled_width
                target_h = block_state.upscaled_height
            else:
                target_w = int(current_w * per_pass_factor)
                target_h = int(current_h * per_pass_factor)
            # Upscale current image to target
            pass_upscaled = current_image.resize((target_w, target_h), PIL.Image.LANCZOS)
            block_state._current_pass_strength = strength_per_pass[p]
            # ControlNet: use the current pass input as control image
            ctrl_pil = None
            if use_controlnet:
                ctrl_pil = pass_upscaled.copy()
            logger.info(
                f"Progressive pass {p + 1}/{num_passes}: "
                f"{current_w}x{current_h} -> {target_w}x{target_h} "
                f"(strength={strength_per_pass[p]:.2f})"
            )
            decoded_np = self._run_single_pass(
                components, block_state,
                upscaled_image=pass_upscaled,
                h=target_h,
                w=target_w,
                ctrl_pil=ctrl_pil,
                use_controlnet=use_controlnet,
                latent_tile_size=latent_tile_size,
                latent_overlap=latent_overlap,
            )
            # Convert decoded to PIL for next pass
            result_uint8 = (np.clip(decoded_np, 0, 1) * 255).astype(np.uint8)
            current_image = PIL.Image.fromarray(result_uint8)
            current_w, current_h = current_image.width, current_image.height
    # --- Format output ---
    h = block_state.upscaled_height
    w = block_state.upscaled_width
    result_uint8 = (np.clip(decoded_np, 0, 1) * 255).astype(np.uint8)
    if output_type == "pil":
        block_state.images = [PIL.Image.fromarray(result_uint8)]
    elif output_type == "np":
        block_state.images = [decoded_np]
    elif output_type == "pt":
        # HWC float array in [0, 1] -> 1xCxHxW tensor.
        block_state.images = [torch.from_numpy(decoded_np).permute(2, 0, 1).unsqueeze(0)]
    else:
        # Any unrecognized output_type falls back to PIL.
        block_state.images = [PIL.Image.fromarray(result_uint8)]
    # --- Feature 4: Output metadata ---
    total_time = time.time() - t_start
    metadata = {
        "input_size": orig_input_size,
        "output_size": (w, h),
        "upscale_factor": upscale_factor,
        "num_passes": num_passes,
        "strength_per_pass": strength_per_pass,
        "total_time": total_time,
    }
    # Metadata is always attached; `return_metadata` only gates the summary log line.
    block_state.metadata = metadata
    if return_metadata:
        logger.info(
            f"MultiDiffusion complete: {orig_input_size} -> ({w}, {h}), "
            f"{num_passes} pass(es), {total_time:.1f}s"
        )
    self.set_block_state(state, block_state)
    return components, state
# ============================================================
# modular_blocks
# ============================================================
"""Block composition for Modular SDXL Upscale."""
from diffusers.utils import logging
from diffusers.modular_pipelines.modular_pipeline import SequentialPipelineBlocks
from diffusers.modular_pipelines.modular_pipeline_utils import OutputParam
from diffusers.modular_pipelines.stable_diffusion_xl.before_denoise import (
StableDiffusionXLImg2ImgSetTimestepsStep,
StableDiffusionXLInputStep,
)
# Re-bind the module logger for this concatenated section (earlier sections of
# this consolidated file create `logger` the same way).
logger = logging.get_logger(__name__)
class MultiDiffusionUpscaleBlocks(SequentialPipelineBlocks):
    """Sequential block composition for seamless tiled SDXL upscaling.

    Chains five stages: prompt encoding, Lanczos pre-upscale, standard SDXL
    input preparation, img2img timestep setup, and a MultiDiffusion stage
    that blends per-tile noise predictions in latent space. The final stage
    also performs the VAE encode/decode internally, using VAE tiling for
    memory efficiency.

    Stage order::

        text_encoder -> upscale -> input -> set_timesteps -> multidiffusion
    """

    # One class per stage, paired positionally with ``block_names`` below.
    block_classes = [
        UltimateSDUpscaleTextEncoderStep,
        UltimateSDUpscaleUpscaleStep,
        StableDiffusionXLInputStep,
        StableDiffusionXLImg2ImgSetTimestepsStep,
        UltimateSDUpscaleMultiDiffusionStep,
    ]
    block_names = ["text_encoder", "upscale", "input", "set_timesteps", "multidiffusion"]

    # Per-workflow input flags — presumably marking which inputs each named
    # workflow consumes; confirm against the workflow loader.
    _workflow_map = {
        "upscale": {"image": True, "prompt": True},
        "upscale_controlnet": {"image": True, "control_image": True, "prompt": True},
    }

    @property
    def description(self):
        summary = (
            "MultiDiffusion upscale pipeline for Stable Diffusion XL.\n"
            "Upscales an input image and refines it using tiled denoising with "
            "latent-space noise prediction blending. Produces seamless output at "
            "any resolution without tile-boundary artifacts.\n"
            "Supports optional ControlNet Tile conditioning for improved fidelity."
        )
        return summary

    @property
    def outputs(self):
        images_param = OutputParam.template("images")
        return [images_param]
# ============================================================
# modular_pipeline
# ============================================================
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modular pipeline class for tiled SDXL upscaling.
Reuses ``StableDiffusionXLModularPipeline`` since all components (VAE, UNet,
text encoders, scheduler) are the same. The only addition is the
``default_blocks_name`` pointing to our custom block composition.
"""
from diffusers.modular_pipelines.stable_diffusion_xl.modular_pipeline import StableDiffusionXLModularPipeline
class UltimateSDUpscaleModularPipeline(StableDiffusionXLModularPipeline):
    """A ModularPipeline for tiled SDXL upscaling.

    Inherits all SDXL component properties (``vae_scale_factor``,
    ``default_sample_size``, etc.) and overrides the default blocks to use
    the tiled upscaling block composition.

    > [!WARNING]
    > This is an experimental feature and is likely to change in the future.
    """

    # Must name a blocks class that actually exists in this repo so block
    # resolution can find it. The composition defined above is
    # ``MultiDiffusionUpscaleBlocks``; the previous value
    # ("UltimateSDUpscaleBlocks") matched no class in this file.
    # NOTE(review): if the Hub repo registers the blocks under the old name,
    # keep both in sync — verify against the repo's block registry.
    default_blocks_name = "MultiDiffusionUpscaleBlocks"