| """MultiDiffusion tiled upscaling for Z-Image using Modular Diffusers. |
| |
| Tiles the latent space, denoises each tile independently per timestep, |
| blends with cosine-ramp overlap weights, and applies one scheduler step |
| on the full blended prediction. Supports optional ControlNet conditioning, |
| progressive upscaling, auto-strength, and metadata output. |
| """ |
|
|
| import math |
| import time |
| from dataclasses import dataclass |
|
|
| import numpy as np |
| import PIL.Image |
| import torch |
|
|
| from diffusers.configuration_utils import FrozenDict |
| from diffusers.guiders import ClassifierFreeGuidance |
| from diffusers.image_processor import VaeImageProcessor |
| from diffusers.models import AutoencoderKL, ZImageTransformer2DModel |
| from diffusers.models.controlnets import ZImageControlNetModel |
| from diffusers.modular_pipelines.modular_pipeline import ( |
| ModularPipelineBlocks, |
| PipelineState, |
| SequentialPipelineBlocks, |
| ) |
| from diffusers.modular_pipelines.modular_pipeline_utils import ( |
| ComponentSpec, |
| InputParam, |
| OutputParam, |
| ) |
| from diffusers.modular_pipelines.z_image.encoders import ( |
| ZImageTextEncoderStep, |
| retrieve_latents, |
| ) |
| from diffusers.modular_pipelines.z_image.modular_pipeline import ZImageModularPipeline |
| from diffusers.schedulers import FlowMatchEulerDiscreteScheduler |
| from diffusers.utils import logging |
| from diffusers.utils.torch_utils import randn_tensor |
|
|
|
|
# Module-level logger using diffusers' wrapper around stdlib logging.
logger = logging.get_logger(__name__)
|
|
|
|
| |
| |
| |
|
|
|
|
@dataclass
class LatentTileSpec:
    """One tile in latent space: top-left corner (y, x) and size (h, w)."""

    y: int  # top row of the tile in latent coordinates
    x: int  # left column of the tile in latent coordinates
    h: int  # tile height in latent rows
    w: int  # tile width in latent columns


def plan_latent_tiles(latent_h, latent_w, tile_size=64, overlap=8):
    """Plan overlapping tiles that cover a ``latent_h`` x ``latent_w`` grid.

    Tiles advance with stride ``tile_size - overlap``. The final tile on each
    axis is snapped back so it ends flush with the grid edge, keeping it
    full-sized whenever the grid is at least one tile wide/tall. A grid
    smaller than ``tile_size`` yields a single grid-sized tile.

    Args:
        latent_h: Latent grid height.
        latent_w: Latent grid width.
        tile_size: Tile side length in latent units.
        overlap: Overlap between adjacent tiles; must satisfy
            ``0 <= overlap < tile_size``.

    Returns:
        List of ``LatentTileSpec`` in row-major order.

    Raises:
        ValueError: If ``tile_size`` is not positive, ``overlap`` is negative,
            or ``overlap >= tile_size``.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")
    if overlap < 0:
        # A negative overlap would make the stride exceed tile_size and leave
        # uncovered gaps between adjacent tiles.
        raise ValueError(f"overlap must be non-negative, got {overlap}")
    if overlap >= tile_size:
        raise ValueError(f"overlap ({overlap}) must be less than tile_size ({tile_size})")
    stride = tile_size - overlap
    tiles = []
    y = 0
    while y < latent_h:
        h = min(tile_size, latent_h - y)
        if h < tile_size and y > 0:
            # Snap the last row back so the tile is full-sized and flush
            # with the bottom edge.
            y = max(0, latent_h - tile_size)
            h = latent_h - y
        x = 0
        while x < latent_w:
            w = min(tile_size, latent_w - x)
            if w < tile_size and x > 0:
                # Snap the last column back likewise.
                x = max(0, latent_w - tile_size)
                w = latent_w - x
            tiles.append(LatentTileSpec(y=y, x=x, h=h, w=w))
            if x + w >= latent_w:
                break
            x += stride
        if y + h >= latent_h:
            break
        y += stride
    return tiles
|
|
|
|
| def _make_cosine_tile_weight( |
| h, w, overlap, device, dtype, is_top=False, is_bottom=False, is_left=False, is_right=False |
| ): |
| def _ramp(length, overlap_size, keep_start, keep_end): |
| ramp = torch.ones(length, device=device, dtype=dtype) |
| if overlap_size > 0 and length > 2 * overlap_size: |
| fade = 0.5 * (1.0 - torch.cos(torch.linspace(0, math.pi, overlap_size, device=device, dtype=dtype))) |
| if not keep_start: |
| ramp[:overlap_size] = fade |
| if not keep_end: |
| ramp[-overlap_size:] = fade.flip(0) |
| return ramp |
|
|
| w_h = _ramp(h, overlap, keep_start=is_top, keep_end=is_bottom) |
| w_w = _ramp(w, overlap, keep_start=is_left, keep_end=is_right) |
| return (w_h[:, None] * w_w[None, :]).unsqueeze(0).unsqueeze(0) |
|
|
|
|
| def _compute_auto_strength(upscale_factor, pass_index, num_passes): |
| if num_passes > 1: |
| return 0.4 if pass_index == 0 else 0.3 |
| if upscale_factor <= 2.0: |
| return 0.4 |
| elif upscale_factor <= 4.0: |
| return 0.25 |
| else: |
| return 0.2 |
|
|
|
|
| |
| |
| |
|
|
|
|
class ZImageUpscaleStep(ModularPipelineBlocks):
    """Pipeline block that upscales the input image with Lanczos resampling.

    Computes the target size from ``scale_factor``, snaps both dimensions
    down to a multiple of the pipeline's spatial VAE scale factor (so the
    result maps cleanly onto the latent grid), and stores the resized image
    plus the final height/width in the pipeline state.
    """

    model_name = "z-image"

    @property
    def description(self) -> str:
        return "Upscale input image with Lanczos interpolation"

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("image", required=True, type_hint=PIL.Image.Image),
            InputParam("scale_factor", default=2.0, type_hint=float),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("upscaled_image", type_hint=PIL.Image.Image),
            OutputParam("height", type_hint=int),
            OutputParam("width", type_hint=int),
        ]

    @torch.no_grad()
    def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
        """Resize the input image and record the snapped output dimensions.

        Raises:
            ValueError: If the scaled size snaps to zero in either dimension
                (input too small and/or scale_factor too low).
        """
        block_state = self.get_block_state(state)
        image = block_state.image
        scale = block_state.scale_factor
        new_w = int(image.width * scale)
        new_h = int(image.height * scale)
        # Snap dimensions DOWN to a multiple of the spatial VAE scale factor.
        sf = components.vae_scale_factor_spatial
        new_w = (new_w // sf) * sf
        new_h = (new_h // sf) * sf
        if new_w <= 0 or new_h <= 0:
            # Fail early with a clear message; PIL.Image.resize((0, h), ...)
            # would otherwise surface an opaque error downstream.
            raise ValueError(
                f"Upscaled size {new_w}x{new_h} is invalid: input image "
                f"({image.width}x{image.height}) at scale_factor={scale} is smaller "
                f"than the VAE spatial scale factor ({sf})."
            )
        block_state.upscaled_image = image.resize((new_w, new_h), PIL.Image.LANCZOS)
        block_state.height = new_h
        block_state.width = new_w
        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
| |
| |
| |
|
|
|
|
class ZImageMultiDiffusionStep(ModularPipelineBlocks):
    """MultiDiffusion tiled denoising for Z-Image with optional ControlNet.

    Encodes the upscaled image to latent space, partially re-noises it
    according to ``strength`` (img2img-style schedule truncation), then
    denoises with overlapping latent tiles: each tile is predicted
    independently per timestep, predictions are blended with cosine-ramp
    weights, and one scheduler step is applied to the full blended
    prediction. Supports progressive 2x passes, auto-strength, and optional
    ControlNet conditioning.
    """

    model_name = "z-image"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        # NOTE(review): vae_scale_factor = 8 * 2 presumably combines the VAE's
        # 8x spatial downsampling with a transformer patch size of 2 — confirm
        # against the Z-Image model config.
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec("transformer", ZImageTransformer2DModel),
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 8 * 2}),
                default_creation_method="from_config",
            ),
            # CFG guider is disabled by default; when disabled the guider
            # passes the conditional branch through unchanged.
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 5.0, "enabled": False}),
                default_creation_method="from_config",
            ),
            # Only exercised when a `control_image` input is supplied.
            ComponentSpec("controlnet", ZImageControlNetModel),
        ]

    @property
    def description(self) -> str:
        return (
            "MultiDiffusion tiled denoising: encodes the full upscaled image, "
            "denoises with overlapping latent tiles and cosine-weighted blending, "
            "then decodes the result. Supports optional ControlNet conditioning."
        )

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("upscaled_image", required=True, type_hint=PIL.Image.Image),
            InputParam("image", type_hint=PIL.Image.Image, description="Original input image (for progressive mode)."),
            InputParam("height", required=True, type_hint=int),
            InputParam("width", required=True, type_hint=int),
            InputParam("scale_factor", default=2.0, type_hint=float),
            InputParam("prompt_embeds", required=True),
            InputParam("negative_prompt_embeds"),
            InputParam("num_inference_steps", default=8, type_hint=int),
            InputParam("strength", default=0.4, type_hint=float),
            InputParam("tile_size", default=64, type_hint=int),
            InputParam("tile_overlap", default=8, type_hint=int),
            InputParam("generator"),
            InputParam("output_type", default="pil", type_hint=str),
            InputParam("control_image", description="Optional ControlNet conditioning image (PIL)."),
            InputParam("controlnet_conditioning_scale", default=0.75, type_hint=float),
            InputParam(
                "progressive",
                default=True,
                type_hint=bool,
                description="Split upscale_factor > 2 into multiple 2x passes.",
            ),
            InputParam(
                "auto_strength",
                default=True,
                type_hint=bool,
                description="Auto-scale strength based on upscale factor and pass index.",
            ),
            InputParam("return_metadata", default=False, type_hint=bool),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("images"),
            OutputParam("metadata", type_hint=dict),
        ]

    def _vae_encode(self, components, image_pil, height, width, generator, device, vae_dtype):
        """VAE-encode a PIL image to latent space.

        Returns latents normalized with the VAE's shift/scaling factors —
        the form the transformer and scheduler operate on.
        """
        image_tensor = components.image_processor.preprocess(image_pil, height=height, width=width)
        image_tensor = image_tensor.to(device=device, dtype=vae_dtype)
        image_latents = retrieve_latents(components.vae.encode(image_tensor), generator=generator)
        # Map raw posterior samples into the normalized latent space.
        image_latents = (image_latents - components.vae.config.shift_factor) * components.vae.config.scaling_factor
        return image_latents

    def _vae_decode(self, components, latents, vae_dtype, output_type):
        """VAE-decode latents to images.

        Inverts the shift/scale normalization applied in ``_vae_encode``
        before decoding, then postprocesses to ``output_type``.
        """
        decode_latents = latents.to(vae_dtype)
        decode_latents = decode_latents / components.vae.config.scaling_factor + components.vae.config.shift_factor
        decoded = components.vae.decode(decode_latents, return_dict=False)[0]
        return components.image_processor.postprocess(decoded, output_type=output_type)

    def _prepare_control_latents(self, components, control_image, height, width, generator, device, vae_dtype):
        """VAE-encode a control image for ControlNet conditioning.

        Resizes a PIL control image to the output size if needed, encodes it
        deterministically (argmax of the posterior), adds a singleton frame
        axis, and zero-pads channels up to the ControlNet's input dim.
        """
        if isinstance(control_image, PIL.Image.Image):
            if control_image.size != (width, height):
                control_image = control_image.resize((width, height), PIL.Image.LANCZOS)
            ctrl_tensor = components.image_processor.preprocess(control_image, height=height, width=width)
        else:
            # Assumed to already be a preprocessed tensor — TODO confirm callers.
            ctrl_tensor = control_image
        ctrl_tensor = ctrl_tensor.to(device=device, dtype=vae_dtype)
        # sample_mode="argmax" makes the control latents deterministic.
        ctrl_latents = retrieve_latents(components.vae.encode(ctrl_tensor), generator=generator, sample_mode="argmax")
        ctrl_latents = (ctrl_latents - components.vae.config.shift_factor) * components.vae.config.scaling_factor
        # Singleton frame axis at dim 2, matching the unsqueeze(2) layout used
        # for the latent model input in _run_tile_transformer.
        ctrl_latents = ctrl_latents.unsqueeze(2)

        # Zero-pad channels when the ControlNet expects more input channels
        # than the VAE latent provides.
        num_channels_latents = components.transformer.in_channels
        if hasattr(components.controlnet, "config") and hasattr(components.controlnet.config, "control_in_dim"):
            if num_channels_latents != components.controlnet.config.control_in_dim:
                pad_channels = components.controlnet.config.control_in_dim - num_channels_latents
                ctrl_latents = torch.cat(
                    [
                        ctrl_latents,
                        torch.zeros(
                            ctrl_latents.shape[0],
                            pad_channels,
                            *ctrl_latents.shape[2:],
                        ).to(device=ctrl_latents.device, dtype=ctrl_latents.dtype),
                    ],
                    dim=1,
                )
        return ctrl_latents

    def _run_tile_transformer(
        self,
        components,
        tile_latents,
        t,
        i,
        num_inference_steps,
        prompt_embeds,
        negative_prompt_embeds,
        dtype,
        controlnet_cond_tile=None,
        controlnet_conditioning_scale=0.75,
    ):
        """Run transformer (+ optional ControlNet) on a single tile.

        Returns the guided noise prediction for the tile, shaped like
        ``tile_latents``.
        """
        # The transformer consumes a list of per-sample tensors with a
        # singleton frame axis at dim 2.
        latent_input = tile_latents.unsqueeze(2).to(dtype)
        latent_model_input = list(latent_input.unbind(dim=0))

        # Remap scheduler timestep from [1000, 0] to [0, 1] — presumably the
        # transformer's flow-matching time convention; TODO confirm against
        # the reference Z-Image denoise loop.
        timestep = t.expand(tile_latents.shape[0]).to(dtype)
        timestep = (1000 - timestep) / 1000

        # Guider batches cond/uncond embeddings (single branch when CFG is
        # disabled).
        guider_inputs = {"cap_feats": (prompt_embeds, negative_prompt_embeds)}
        components.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
        guider_state = components.guider.prepare_inputs(guider_inputs)

        for guider_state_batch in guider_state:
            components.guider.prepare_models(components.transformer)

            # Cast conditioning tensors (or lists of tensors) to the
            # transformer dtype.
            cond_kwargs = {}
            for k, v in guider_state_batch.as_dict().items():
                if k in guider_inputs:
                    if isinstance(v, torch.Tensor):
                        cond_kwargs[k] = v.to(dtype)
                    elif isinstance(v, list):
                        cond_kwargs[k] = [x.to(dtype) if isinstance(x, torch.Tensor) else x for x in v]
                    else:
                        cond_kwargs[k] = v

            controlnet_block_samples = None
            if controlnet_cond_tile is not None and getattr(components, "controlnet", None) is not None:
                cap_feats_for_cn = cond_kwargs.get("cap_feats", prompt_embeds)
                controlnet_block_samples = components.controlnet(
                    latent_model_input,
                    timestep,
                    cap_feats_for_cn,
                    controlnet_cond_tile,
                    conditioning_scale=controlnet_conditioning_scale,
                )

            transformer_kwargs = {"x": latent_model_input, "t": timestep, "return_dict": False, **cond_kwargs}
            if controlnet_block_samples is not None:
                transformer_kwargs["controlnet_block_samples"] = controlnet_block_samples

            model_out_list = components.transformer(**transformer_kwargs)[0]
            # Re-stack per-sample outputs and drop the frame axis; the output
            # is negated — presumably to match the scheduler's expected
            # velocity/noise sign convention; TODO confirm.
            noise_pred = torch.stack(model_out_list, dim=0).squeeze(2)
            guider_state_batch.noise_pred = -noise_pred
            components.guider.cleanup_models(components.transformer)

        # Combine cond/uncond predictions (identity when CFG is disabled).
        return components.guider(guider_state)[0]

    def _run_single_pass(
        self,
        components,
        block_state,
        upscaled_image,
        h,
        w,
        control_image,
        use_controlnet,
        tile_size,
        tile_overlap,
        pass_strength,
    ):
        """Run one MultiDiffusion encode-denoise-decode pass, return decoded numpy."""
        device = components._execution_device
        vae_dtype = components.vae.dtype
        dtype = components.transformer.dtype
        generator = block_state.generator

        # VAE tiling keeps encode/decode memory bounded at large resolutions.
        if hasattr(components.vae, "enable_tiling"):
            components.vae.enable_tiling()

        image_latents = self._vae_encode(components, upscaled_image, h, w, generator, device, vae_dtype)

        full_control_latents = None
        if use_controlnet and control_image is not None:
            full_control_latents = self._prepare_control_latents(
                components, control_image, h, w, generator, device, vae_dtype
            )

        # img2img-style schedule truncation: only the last
        # num_inference_steps * pass_strength steps are executed.
        num_inference_steps = block_state.num_inference_steps
        components.scheduler.set_timesteps(num_inference_steps, device=device)
        all_timesteps = components.scheduler.timesteps
        init_timestep = min(int(num_inference_steps * pass_strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = all_timesteps[t_start:]
        num_inf_steps = len(timesteps)

        if num_inf_steps == 0:
            # strength ~ 0: nothing to denoise; pass latents straight through.
            latents = image_latents
        else:
            # Noise the clean latents up to the first retained timestep.
            noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=image_latents.dtype)
            latent_timestep = timesteps[:1].repeat(image_latents.shape[0])
            latents = components.scheduler.scale_noise(image_latents, latent_timestep, noise)

        latent_h, latent_w = latents.shape[2], latents.shape[3]
        tile_specs = plan_latent_tiles(latent_h, latent_w, tile_size, tile_overlap)
        logger.info(f"MultiDiffusion: {len(tile_specs)} tiles, latent {latent_w}x{latent_h}")

        prompt_embeds = block_state.prompt_embeds
        negative_prompt_embeds = getattr(block_state, "negative_prompt_embeds", None)
        cn_scale = getattr(block_state, "controlnet_conditioning_scale", 0.75)

        for i, t in enumerate(timesteps):
            # Accumulate weighted tile predictions in float32 for stability.
            noise_pred_accum = torch.zeros_like(latents, dtype=torch.float32)
            weight_accum = torch.zeros(1, 1, latent_h, latent_w, device=device, dtype=torch.float32)

            for tile in tile_specs:
                tile_latents = latents[:, :, tile.y : tile.y + tile.h, tile.x : tile.x + tile.w].clone()

                cn_tile = None
                if use_controlnet and full_control_latents is not None:
                    # Control latents carry an extra frame axis at dim 2.
                    cn_tile = full_control_latents[:, :, :, tile.y : tile.y + tile.h, tile.x : tile.x + tile.w]

                tile_noise_pred = self._run_tile_transformer(
                    components,
                    tile_latents,
                    t,
                    i,
                    num_inf_steps,
                    prompt_embeds,
                    negative_prompt_embeds,
                    dtype,
                    controlnet_cond_tile=cn_tile,
                    controlnet_conditioning_scale=cn_scale,
                )

                # Cosine weights fade only on the sides that face a neighbour;
                # sides touching the image border keep full weight.
                tile_weight = _make_cosine_tile_weight(
                    tile.h,
                    tile.w,
                    tile_overlap,
                    device,
                    torch.float32,
                    is_top=(tile.y == 0),
                    is_bottom=(tile.y + tile.h >= latent_h),
                    is_left=(tile.x == 0),
                    is_right=(tile.x + tile.w >= latent_w),
                )

                noise_pred_accum[:, :, tile.y : tile.y + tile.h, tile.x : tile.x + tile.w] += (
                    tile_noise_pred.to(torch.float32) * tile_weight
                )
                weight_accum[:, :, tile.y : tile.y + tile.h, tile.x : tile.x + tile.w] += tile_weight

            # MultiDiffusion: normalize the overlap-weighted sum, then take a
            # single scheduler step on the full blended prediction. The clamp
            # guards against divide-by-zero where no fade reached a pixel.
            blended = noise_pred_accum / weight_accum.clamp(min=1e-6)
            blended = torch.nan_to_num(blended, nan=0.0, posinf=0.0, neginf=0.0).to(latents.dtype)
            latents = components.scheduler.step(blended.float(), t, latents.float(), return_dict=False)[0]
            latents = latents.to(dtype=image_latents.dtype)

        decoded = self._vae_decode(components, latents, vae_dtype, "np")
        return decoded[0]

    @torch.no_grad()
    def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
        """Orchestrate one or more MultiDiffusion passes and set the outputs.

        Chooses between a single pass and progressive 2x passes, resolves
        per-pass strengths, validates ControlNet availability, then converts
        the final decoded array to ``output_type`` and records metadata.
        """
        block_state = self.get_block_state(state)
        t_start = time.time()

        output_type = block_state.output_type
        tile_size = block_state.tile_size
        tile_overlap = block_state.tile_overlap

        upscale_factor = getattr(block_state, "scale_factor", 2.0)
        progressive = getattr(block_state, "progressive", True)
        auto_strength = getattr(block_state, "auto_strength", True)
        return_metadata = getattr(block_state, "return_metadata", False)
        user_strength = block_state.strength

        # Progressive mode: split factors > 2 into ceil(log2(factor)) passes.
        if progressive and upscale_factor > 2.0:
            num_passes = max(1, int(math.ceil(math.log2(upscale_factor))))
        else:
            num_passes = 1

        # Resolve the denoising strength for each pass up front.
        strength_per_pass = []
        for p in range(num_passes):
            if auto_strength:
                strength_per_pass.append(_compute_auto_strength(upscale_factor, p, num_passes))
            else:
                strength_per_pass.append(user_strength)

        # ControlNet is opt-in via the control_image input; fail fast if the
        # component is absent.
        control_image_raw = getattr(block_state, "control_image", None)
        use_controlnet = False
        if control_image_raw is not None:
            if not hasattr(components, "controlnet") or components.controlnet is None:
                raise ValueError("`control_image` provided but `controlnet` component is missing.")
            use_controlnet = True
            logger.info("MultiDiffusion: ControlNet enabled.")

        # NOTE(review): this records the size of the *upscaled* image under
        # the "input_size" metadata key — verify that is the intent.
        orig_input_size = (block_state.upscaled_image.width, block_state.upscaled_image.height)
        original_image = getattr(block_state, "image", None)

        if num_passes == 1:
            ctrl_pil = control_image_raw if use_controlnet else None
            decoded_np = self._run_single_pass(
                components,
                block_state,
                upscaled_image=block_state.upscaled_image,
                h=block_state.height,
                w=block_state.width,
                control_image=ctrl_pil,
                use_controlnet=use_controlnet,
                tile_size=tile_size,
                tile_overlap=tile_overlap,
                pass_strength=strength_per_pass[0],
            )
        else:
            # Progressive mode restarts from the original-resolution image;
            # reconstruct it by downscaling when the caller didn't provide one.
            if original_image is None:
                original_image = block_state.upscaled_image.resize(
                    (int(block_state.width / upscale_factor), int(block_state.height / upscale_factor)),
                    PIL.Image.LANCZOS,
                )

            current_image = original_image
            current_w, current_h = current_image.width, current_image.height

            for p in range(num_passes):
                # Intermediate passes double the size; the last pass lands on
                # the requested output size.
                if p == num_passes - 1:
                    target_w = block_state.width
                    target_h = block_state.height
                else:
                    target_w = int(current_w * 2.0)
                    target_h = int(current_h * 2.0)

                # Snap to multiples of the spatial VAE scale factor.
                sf = components.vae_scale_factor_spatial
                target_w = (target_w // sf) * sf
                target_h = (target_h // sf) * sf

                pass_upscaled = current_image.resize((target_w, target_h), PIL.Image.LANCZOS)
                # In progressive mode the ControlNet is conditioned on the
                # current pass's own Lanczos upscale.
                ctrl_pil = pass_upscaled.copy() if use_controlnet else None

                logger.info(
                    f"Progressive pass {p + 1}/{num_passes}: "
                    f"{current_w}x{current_h} -> {target_w}x{target_h} "
                    f"(strength={strength_per_pass[p]:.2f})"
                )

                decoded_np = self._run_single_pass(
                    components,
                    block_state,
                    upscaled_image=pass_upscaled,
                    h=target_h,
                    w=target_w,
                    control_image=ctrl_pil,
                    use_controlnet=use_controlnet,
                    tile_size=tile_size,
                    tile_overlap=tile_overlap,
                    pass_strength=strength_per_pass[p],
                )

                # Feed the decoded result into the next pass as a PIL image.
                result_uint8 = (np.clip(decoded_np, 0, 1) * 255).astype(np.uint8)
                current_image = PIL.Image.fromarray(result_uint8)
                current_w, current_h = current_image.width, current_image.height

        # Convert the final decoded float array to the requested output type.
        result_uint8 = (np.clip(decoded_np, 0, 1) * 255).astype(np.uint8)
        if output_type == "pil":
            block_state.images = [PIL.Image.fromarray(result_uint8)]
        elif output_type == "np":
            block_state.images = [decoded_np]
        elif output_type == "pt":
            block_state.images = [torch.from_numpy(decoded_np).permute(2, 0, 1).unsqueeze(0)]
        else:
            # Unknown output_type: fall back to PIL.
            block_state.images = [PIL.Image.fromarray(result_uint8)]

        # Metadata is always populated; return_metadata only controls printing.
        total_time = time.time() - t_start
        block_state.metadata = {
            "input_size": orig_input_size,
            "output_size": (block_state.width, block_state.height),
            "upscale_factor": upscale_factor,
            "num_passes": num_passes,
            "strength_per_pass": strength_per_pass,
            "total_time": total_time,
        }

        if return_metadata:
            print(f" Input size: {orig_input_size}")
            print(f" Output size: ({block_state.width}, {block_state.height})")
            print(f" Upscale factor: {upscale_factor}")
            print(f" Num passes: {num_passes}")
            print(f" Strength per pass: {strength_per_pass}")
            print(f" Total time: {total_time:.1f}s")

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
| |
| |
| |
|
|
|
|
class MultiDiffusionUpscaleBlocks(SequentialPipelineBlocks):
    """Sequential composition of the full MultiDiffusion upscale pipeline.

    Runs text encoding, Lanczos upscale, and MultiDiffusion tiled denoising
    in order; ``block_names`` are the keys used to address the sub-blocks.
    """

    model_name = "z-image"
    block_classes = [
        ZImageTextEncoderStep,
        ZImageUpscaleStep,
        ZImageMultiDiffusionStep,
    ]
    block_names = ["text_encoder", "upscale", "multidiffusion"]

    @property
    def description(self):
        return (
            "MultiDiffusion upscale pipeline for Z-Image.\n"
            "1. Text encoding (Qwen3)\n"
            "2. Lanczos upscale\n"
            "3. MultiDiffusion tiled denoise + VAE decode"
        )

    @property
    def outputs(self):
        return [
            OutputParam("images", description="The upscaled images."),
            OutputParam("metadata", type_hint=dict, description="Generation metadata."),
        ]
|
|