"""MultiDiffusion tiled upscaling for Z-Image using Modular Diffusers.

Tiles the latent space, denoises each tile independently per timestep, blends with
cosine-ramp overlap weights, and applies one scheduler step on the full blended
prediction. Supports optional ControlNet conditioning, progressive upscaling,
auto-strength, and metadata output.
"""

import math
import time
from dataclasses import dataclass

import numpy as np
import PIL.Image
import torch

from diffusers.configuration_utils import FrozenDict
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, ZImageTransformer2DModel
from diffusers.models.controlnets import ZImageControlNetModel
from diffusers.modular_pipelines.modular_pipeline import (
    ModularPipelineBlocks,
    PipelineState,
    SequentialPipelineBlocks,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    InputParam,
    OutputParam,
)
from diffusers.modular_pipelines.z_image.encoders import (
    ZImageTextEncoderStep,
    retrieve_latents,
)
from diffusers.modular_pipelines.z_image.modular_pipeline import ZImageModularPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor


logger = logging.get_logger(__name__)


# ============================================================
# Tiling utilities
# ============================================================


@dataclass
class LatentTileSpec:
    """Axis-aligned tile in latent coordinates: top-left corner plus size."""

    y: int  # top row (inclusive)
    x: int  # left column (inclusive)
    h: int  # number of rows
    w: int  # number of columns


def plan_latent_tiles(latent_h, latent_w, tile_size=64, overlap=8):
    """Plan overlapping tiles that together cover a ``latent_h x latent_w`` grid.

    Tiles advance by ``tile_size - overlap``. When the remaining extent along an axis is
    shorter than ``tile_size``, the last tile on that axis is shifted back so that it ends
    exactly at the border instead of producing a thin sliver.

    Args:
        latent_h: Height of the latent grid.
        latent_w: Width of the latent grid.
        tile_size: Nominal tile edge length in latent units.
        overlap: Number of latent rows/columns shared by adjacent tiles.

    Returns:
        List of ``LatentTileSpec`` in row-major order (rows outer, columns inner).

    Raises:
        ValueError: If ``tile_size`` is not positive, ``overlap`` is negative, or
            ``overlap >= tile_size``.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")
    if overlap < 0:
        # A negative overlap makes the stride larger than the tile, leaving uncovered gaps
        # that would never receive a noise prediction.
        raise ValueError(f"overlap must be non-negative, got {overlap}")
    if overlap >= tile_size:
        raise ValueError(f"overlap ({overlap}) must be less than tile_size ({tile_size})")

    stride = tile_size - overlap
    tiles = []

    y = 0
    while y < latent_h:
        h = min(tile_size, latent_h - y)
        if h < tile_size and y > 0:
            # Snap the final row of tiles back so it ends flush with the bottom border.
            y = max(0, latent_h - tile_size)
            h = latent_h - y

        x = 0
        while x < latent_w:
            w = min(tile_size, latent_w - x)
            if w < tile_size and x > 0:
                # Snap the final column of tiles back so it ends flush with the right border.
                x = max(0, latent_w - tile_size)
                w = latent_w - x

            tiles.append(LatentTileSpec(y=y, x=x, h=h, w=w))

            if x + w >= latent_w:
                break
            x += stride

        if y + h >= latent_h:
            break
        y += stride

    return tiles


def _make_cosine_tile_weight(
    h, w, overlap, device, dtype, is_top=False, is_bottom=False, is_left=False, is_right=False
):
    """Build a separable ``[1, 1, h, w]`` blending weight for a single tile.

    Along each axis the weight is 1, except over ``overlap`` samples at every edge that is
    not on the latent border (``is_top``/``is_bottom``/``is_left``/``is_right`` keep weight
    1). Those edges use a raised-cosine ramp. The ramp is sampled at the interior points of
    ``[0, pi]``, so every weight is strictly positive. Where two tiles overlap by exactly
    ``overlap`` samples, their mirrored ramps sum to 1. No ramp is applied along an axis
    whose length is not greater than ``2 * overlap``.
    """

    def _ramp(length, overlap_size, keep_start, keep_end):
        ramp = torch.ones(length, device=device, dtype=dtype)
        if overlap_size > 0 and length > 2 * overlap_size:
            # Exclude the 0 and pi endpoints. Including them gives the outermost row/column a
            # weight of exactly 0; with overlap == 1 the seam would then get zero total weight
            # from both neighbours and its blended prediction would collapse to 0.
            angles = torch.linspace(0, math.pi, overlap_size + 2, device=device, dtype=dtype)[1:-1]
            fade = 0.5 * (1.0 - torch.cos(angles))
            if not keep_start:
                ramp[:overlap_size] = fade
            if not keep_end:
                ramp[-overlap_size:] = fade.flip(0)
        return ramp

    w_h = _ramp(h, overlap, keep_start=is_top, keep_end=is_bottom)
    w_w = _ramp(w, overlap, keep_start=is_left, keep_end=is_right)
    # Outer product of the two 1-D ramps, then add leading batch/channel dims for broadcasting.
    return (w_h[:, None] * w_w[None, :]).unsqueeze(0).unsqueeze(0)


def _compute_auto_strength(upscale_factor, pass_index, num_passes):
    """Heuristic denoising strength for one pass.

    Progressive runs (``num_passes > 1``) use 0.4 for the first pass and 0.3 afterwards.
    Single-pass runs use 0.4 up to 2x, 0.25 up to 4x, and 0.2 beyond that.
    """
    if num_passes > 1:
        return 0.4 if pass_index == 0 else 0.3
    if upscale_factor <= 2.0:
        return 0.4
    elif upscale_factor <= 4.0:
        return 0.25
    else:
        return 0.2


# ============================================================
# Upscale step
# ============================================================


class ZImageUpscaleStep(ModularPipelineBlocks):
    """Resize the input image by ``scale_factor`` with Lanczos resampling."""

    model_name = "z-image"

    @property
    def description(self) -> str:
        return "Upscale input image with Lanczos interpolation"

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("image", required=True, type_hint=PIL.Image.Image),
            InputParam("scale_factor", default=2.0, type_hint=float),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("upscaled_image", type_hint=PIL.Image.Image),
            OutputParam("height", type_hint=int),
            OutputParam("width", type_hint=int),
        ]

    @torch.no_grad()
    def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        image = block_state.image
        scale = block_state.scale_factor
        sf = components.vae_scale_factor_spatial

        # Round both sides down to a multiple of the pipeline's `vae_scale_factor_spatial`.
        new_w = (int(image.width * scale) // sf) * sf
        new_h = (int(image.height * scale) // sf) * sf
        if new_w <= 0 or new_h <= 0:
            raise ValueError(
                f"Upscaled size {new_w}x{new_h} is empty: input {image.width}x{image.height} with "
                f"scale_factor={scale} must give at least {sf} pixels per side."
            )

        block_state.upscaled_image = image.resize((new_w, new_h), PIL.Image.LANCZOS)
        block_state.height = new_h
        block_state.width = new_w

        self.set_block_state(state, block_state)
        return components, state


# ============================================================
# MultiDiffusion denoise step
# ============================================================


class ZImageMultiDiffusionStep(ModularPipelineBlocks):
    """MultiDiffusion tiled denoising for Z-Image with optional ControlNet."""

    model_name = "z-image"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec("transformer", ZImageTransformer2DModel),
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 8 * 2}),
                default_creation_method="from_config",
            ),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 5.0, "enabled": False}),
                default_creation_method="from_config",
            ),
            ComponentSpec("controlnet", ZImageControlNetModel),
        ]

    @property
    def description(self) -> str:
        return (
            "MultiDiffusion tiled denoising: encodes the full upscaled image, "
            "denoises with overlapping latent tiles and cosine-weighted blending, "
            "then decodes the result. Supports optional ControlNet conditioning."
        )

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("upscaled_image", required=True, type_hint=PIL.Image.Image),
            InputParam("image", type_hint=PIL.Image.Image, description="Original input image (for progressive mode)."),
            InputParam("height", required=True, type_hint=int),
            InputParam("width", required=True, type_hint=int),
            InputParam("scale_factor", default=2.0, type_hint=float),
            InputParam("prompt_embeds", required=True),
            InputParam("negative_prompt_embeds"),
            InputParam("num_inference_steps", default=8, type_hint=int),
            InputParam("strength", default=0.4, type_hint=float),
            InputParam("tile_size", default=64, type_hint=int),
            InputParam("tile_overlap", default=8, type_hint=int),
            InputParam("generator"),
            InputParam("output_type", default="pil", type_hint=str),
            InputParam("control_image", description="Optional ControlNet conditioning image (PIL)."),
            InputParam("controlnet_conditioning_scale", default=0.75, type_hint=float),
            InputParam(
                "progressive",
                default=True,
                type_hint=bool,
                description="Split upscale_factor > 2 into multiple 2x passes.",
            ),
            InputParam(
                "auto_strength",
                default=True,
                type_hint=bool,
                description="Auto-scale strength based on upscale factor and pass index.",
            ),
            InputParam("return_metadata", default=False, type_hint=bool),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("images"),
            OutputParam("metadata", type_hint=dict),
        ]

    def _vae_encode(self, components, image_pil, height, width, generator, device, vae_dtype):
        """VAE-encode a PIL image and apply the VAE's shift/scale normalisation."""
        image_tensor = components.image_processor.preprocess(image_pil, height=height, width=width)
        image_tensor = image_tensor.to(device=device, dtype=vae_dtype)
        image_latents = retrieve_latents(components.vae.encode(image_tensor), generator=generator)
        image_latents = (image_latents - components.vae.config.shift_factor) * components.vae.config.scaling_factor
        return image_latents

    def _vae_decode(self, components, latents, vae_dtype, output_type):
        """Undo the shift/scale normalisation, VAE-decode, and post-process to ``output_type``."""
        decode_latents = latents.to(vae_dtype)
        decode_latents = decode_latents / components.vae.config.scaling_factor + components.vae.config.shift_factor
        decoded = components.vae.decode(decode_latents, return_dict=False)[0]
        return components.image_processor.postprocess(decoded, output_type=output_type)

    def _prepare_control_latents(self, components, control_image, height, width, generator, device, vae_dtype):
        """VAE-encode a control image for ControlNet conditioning.

        PIL inputs are resized to ``(width, height)`` if needed; non-PIL inputs are used as
        already-preprocessed tensors. The result gets an extra frame dim at index 2. If the
        ControlNet's ``control_in_dim`` differs from the transformer's ``in_channels``, the
        channel dim is zero-padded by the difference.
        """
        if isinstance(control_image, PIL.Image.Image):
            if control_image.size != (width, height):
                control_image = control_image.resize((width, height), PIL.Image.LANCZOS)
            ctrl_tensor = components.image_processor.preprocess(control_image, height=height, width=width)
        else:
            ctrl_tensor = control_image
        ctrl_tensor = ctrl_tensor.to(device=device, dtype=vae_dtype)

        # Deterministic encoding (distribution mode) for the conditioning signal.
        ctrl_latents = retrieve_latents(components.vae.encode(ctrl_tensor), generator=generator, sample_mode="argmax")
        ctrl_latents = (ctrl_latents - components.vae.config.shift_factor) * components.vae.config.scaling_factor
        ctrl_latents = ctrl_latents.unsqueeze(2)  # [B, C, 1, H, W]

        # Pad channels if controlnet expects more
        num_channels_latents = components.transformer.in_channels
        if hasattr(components.controlnet, "config") and hasattr(components.controlnet.config, "control_in_dim"):
            if num_channels_latents != components.controlnet.config.control_in_dim:
                pad_channels = components.controlnet.config.control_in_dim - num_channels_latents
                ctrl_latents = torch.cat(
                    [
                        ctrl_latents,
                        torch.zeros(
                            ctrl_latents.shape[0],
                            pad_channels,
                            *ctrl_latents.shape[2:],
                        ).to(device=ctrl_latents.device, dtype=ctrl_latents.dtype),
                    ],
                    dim=1,
                )
        return ctrl_latents

    def _run_tile_transformer(
        self,
        components,
        tile_latents,
        t,
        i,
        num_inference_steps,
        prompt_embeds,
        negative_prompt_embeds,
        dtype,
        controlnet_cond_tile=None,
        controlnet_conditioning_scale=0.75,
    ):
        """Run the transformer (plus optional ControlNet) on one latent tile and apply the guider.

        ``tile_latents`` is a 4-D slice ``[B, C, h, w]`` of the full latents. It is given a
        singleton dim at index 2 and split into a per-sample list, which is the form passed
        as ``x`` to the transformer. Returns the guided prediction (``guider(...)[0]``).
        """
        latent_input = tile_latents.unsqueeze(2).to(dtype)
        latent_model_input = list(latent_input.unbind(dim=0))
        timestep = t.expand(tile_latents.shape[0]).to(dtype)
        # Map the scheduler timestep to the transformer's time input
        # (assumed [0, 1] "1 - t/1000" convention — TODO confirm against the Z-Image transformer).
        timestep = (1000 - timestep) / 1000

        guider_inputs = {"cap_feats": (prompt_embeds, negative_prompt_embeds)}
        components.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
        guider_state = components.guider.prepare_inputs(guider_inputs)

        for guider_state_batch in guider_state:
            components.guider.prepare_models(components.transformer)

            # Forward only the guider-managed inputs, cast to the transformer dtype.
            cond_kwargs = {}
            for k, v in guider_state_batch.as_dict().items():
                if k in guider_inputs:
                    if isinstance(v, torch.Tensor):
                        cond_kwargs[k] = v.to(dtype)
                    elif isinstance(v, list):
                        cond_kwargs[k] = [x.to(dtype) if isinstance(x, torch.Tensor) else x for x in v]
                    else:
                        cond_kwargs[k] = v

            controlnet_block_samples = None
            if controlnet_cond_tile is not None and getattr(components, "controlnet", None) is not None:
                cap_feats_for_cn = cond_kwargs.get("cap_feats", prompt_embeds)
                controlnet_block_samples = components.controlnet(
                    latent_model_input,
                    timestep,
                    cap_feats_for_cn,
                    controlnet_cond_tile,
                    conditioning_scale=controlnet_conditioning_scale,
                )

            transformer_kwargs = {"x": latent_model_input, "t": timestep, "return_dict": False, **cond_kwargs}
            if controlnet_block_samples is not None:
                transformer_kwargs["controlnet_block_samples"] = controlnet_block_samples

            model_out_list = components.transformer(**transformer_kwargs)[0]
            noise_pred = torch.stack(model_out_list, dim=0).squeeze(2)
            # Sign flip before guidance/scheduler step (assumed sign convention between the
            # transformer output and the flow-match scheduler — confirm).
            guider_state_batch.noise_pred = -noise_pred
            components.guider.cleanup_models(components.transformer)

        return components.guider(guider_state)[0]

    def _run_single_pass(
        self,
        components,
        block_state,
        upscaled_image,
        h,
        w,
        control_image,
        use_controlnet,
        tile_size,
        tile_overlap,
        pass_strength,
    ):
        """Run one MultiDiffusion encode-denoise-decode pass and return the decoded numpy image.

        The steps are:

        1. Encode ``upscaled_image``.
        2. Noise it to the first timestep of the last ``pass_strength`` fraction of the
           schedule.
        3. For each remaining timestep, predict noise per tile, blend the predictions with
           cosine weights, and take one scheduler step on the full latent.
        4. Decode the result.

        Side effect: enables VAE tiling on the shared VAE component when it supports it.
        Returns the first image from the ``"np"`` post-processing output.
        """
        device = components._execution_device
        vae_dtype = components.vae.dtype
        dtype = components.transformer.dtype
        generator = block_state.generator

        if hasattr(components.vae, "enable_tiling"):
            components.vae.enable_tiling()

        image_latents = self._vae_encode(components, upscaled_image, h, w, generator, device, vae_dtype)

        # ControlNet latents
        full_control_latents = None
        if use_controlnet and control_image is not None:
            full_control_latents = self._prepare_control_latents(
                components, control_image, h, w, generator, device, vae_dtype
            )

        # Timesteps with pass strength: skip the first (1 - strength) part of the schedule.
        num_inference_steps = block_state.num_inference_steps
        components.scheduler.set_timesteps(num_inference_steps, device=device)
        all_timesteps = components.scheduler.timesteps
        init_timestep = min(int(num_inference_steps * pass_strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = all_timesteps[t_start:]
        num_inf_steps = len(timesteps)

        if num_inf_steps == 0:
            latents = image_latents
        else:
            noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=image_latents.dtype)
            latent_timestep = timesteps[:1].repeat(image_latents.shape[0])
            latents = components.scheduler.scale_noise(image_latents, latent_timestep, noise)

        latent_h, latent_w = latents.shape[2], latents.shape[3]
        tile_specs = plan_latent_tiles(latent_h, latent_w, tile_size, tile_overlap)
        logger.info(f"MultiDiffusion: {len(tile_specs)} tiles, latent {latent_w}x{latent_h}")

        # Blending weights depend only on tile geometry, so build them and their per-pixel
        # sum once per pass rather than once per tile per timestep.
        tile_regions = [
            (slice(tile.y, tile.y + tile.h), slice(tile.x, tile.x + tile.w)) for tile in tile_specs
        ]
        tile_weights = [
            _make_cosine_tile_weight(
                tile.h,
                tile.w,
                tile_overlap,
                device,
                torch.float32,
                is_top=(tile.y == 0),
                is_bottom=(tile.y + tile.h >= latent_h),
                is_left=(tile.x == 0),
                is_right=(tile.x + tile.w >= latent_w),
            )
            for tile in tile_specs
        ]
        weight_accum = torch.zeros(1, 1, latent_h, latent_w, device=device, dtype=torch.float32)
        for (rows, cols), tile_weight in zip(tile_regions, tile_weights):
            weight_accum[:, :, rows, cols] += tile_weight
        weight_norm = weight_accum.clamp(min=1e-6)

        prompt_embeds = block_state.prompt_embeds
        negative_prompt_embeds = getattr(block_state, "negative_prompt_embeds", None)
        cn_scale = getattr(block_state, "controlnet_conditioning_scale", 0.75)

        # Keep the running latents in float32 across steps. The transformer input is cast to
        # the model dtype per tile, and the result is cast back once after the loop.
        latents = latents.float()

        for i, t in enumerate(timesteps):
            noise_pred_accum = torch.zeros_like(latents, dtype=torch.float32)

            for (rows, cols), tile_weight in zip(tile_regions, tile_weights):
                tile_latents = latents[:, :, rows, cols].clone()

                cn_tile = None
                if use_controlnet and full_control_latents is not None:
                    cn_tile = full_control_latents[:, :, :, rows, cols]

                tile_noise_pred = self._run_tile_transformer(
                    components,
                    tile_latents,
                    t,
                    i,
                    num_inf_steps,
                    prompt_embeds,
                    negative_prompt_embeds,
                    dtype,
                    controlnet_cond_tile=cn_tile,
                    controlnet_conditioning_scale=cn_scale,
                )
                noise_pred_accum[:, :, rows, cols] += tile_noise_pred.to(torch.float32) * tile_weight

            blended = noise_pred_accum / weight_norm
            blended = torch.nan_to_num(blended, nan=0.0, posinf=0.0, neginf=0.0)
            latents = components.scheduler.step(blended, t, latents, return_dict=False)[0]

        latents = latents.to(dtype=image_latents.dtype)
        decoded = self._vae_decode(components, latents, vae_dtype, "np")
        return decoded[0]

    @torch.no_grad()
    def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
        """Run one or more MultiDiffusion passes and store ``images`` and ``metadata`` in the state.

        With ``progressive`` and ``scale_factor > 2``, the upscale is split into
        ``ceil(log2(scale_factor))`` passes. Every pass but the last doubles the current size.
        In progressive mode, the Lanczos-upscaled image of each pass is used as the
        ControlNet condition, so a user ``control_image`` only enables ControlNet there.
        With ``auto_strength`` (the default), the per-pass strength comes from
        ``_compute_auto_strength`` and the ``strength`` input is ignored.
        """
        block_state = self.get_block_state(state)
        start_time = time.perf_counter()

        output_type = block_state.output_type
        tile_size = block_state.tile_size
        tile_overlap = block_state.tile_overlap
        upscale_factor = getattr(block_state, "scale_factor", 2.0)
        progressive = getattr(block_state, "progressive", True)
        auto_strength = getattr(block_state, "auto_strength", True)
        return_metadata = getattr(block_state, "return_metadata", False)
        user_strength = block_state.strength

        # Progressive passes
        if progressive and upscale_factor > 2.0:
            num_passes = max(1, int(math.ceil(math.log2(upscale_factor))))
        else:
            num_passes = 1

        # Strength per pass
        if auto_strength:
            strength_per_pass = [_compute_auto_strength(upscale_factor, p, num_passes) for p in range(num_passes)]
        else:
            strength_per_pass = [user_strength] * num_passes

        # ControlNet setup
        control_image_raw = getattr(block_state, "control_image", None)
        use_controlnet = False
        if control_image_raw is not None:
            if getattr(components, "controlnet", None) is None:
                raise ValueError("`control_image` provided but `controlnet` component is missing.")
            use_controlnet = True
            logger.info("MultiDiffusion: ControlNet enabled.")

        original_image = getattr(block_state, "image", None)
        # Report the size of the original (pre-upscale) input, not the upscaled image.
        if original_image is not None:
            orig_input_size = (original_image.width, original_image.height)
        elif upscale_factor > 0:
            orig_input_size = (int(block_state.width / upscale_factor), int(block_state.height / upscale_factor))
        else:
            orig_input_size = (block_state.upscaled_image.width, block_state.upscaled_image.height)

        if num_passes == 1:
            ctrl_pil = control_image_raw if use_controlnet else None
            decoded_np = self._run_single_pass(
                components,
                block_state,
                upscaled_image=block_state.upscaled_image,
                h=block_state.height,
                w=block_state.width,
                control_image=ctrl_pil,
                use_controlnet=use_controlnet,
                tile_size=tile_size,
                tile_overlap=tile_overlap,
                pass_strength=strength_per_pass[0],
            )
        else:
            if original_image is None:
                # Reconstruct an approximate source image from the upscaled one.
                original_image = block_state.upscaled_image.resize(
                    (int(block_state.width / upscale_factor), int(block_state.height / upscale_factor)),
                    PIL.Image.LANCZOS,
                )

            current_image = original_image
            current_w, current_h = current_image.width, current_image.height
            sf = components.vae_scale_factor_spatial

            for p in range(num_passes):
                if p == num_passes - 1:
                    target_w = block_state.width
                    target_h = block_state.height
                else:
                    target_w = int(current_w * 2.0)
                    target_h = int(current_h * 2.0)
                target_w = (target_w // sf) * sf
                target_h = (target_h // sf) * sf

                pass_upscaled = current_image.resize((target_w, target_h), PIL.Image.LANCZOS)
                ctrl_pil = pass_upscaled.copy() if use_controlnet else None

                logger.info(
                    f"Progressive pass {p + 1}/{num_passes}: "
                    f"{current_w}x{current_h} -> {target_w}x{target_h} "
                    f"(strength={strength_per_pass[p]:.2f})"
                )

                decoded_np = self._run_single_pass(
                    components,
                    block_state,
                    upscaled_image=pass_upscaled,
                    h=target_h,
                    w=target_w,
                    control_image=ctrl_pil,
                    use_controlnet=use_controlnet,
                    tile_size=tile_size,
                    tile_overlap=tile_overlap,
                    pass_strength=strength_per_pass[p],
                )

                # The next pass starts from this pass's 8-bit result.
                result_uint8 = (np.clip(decoded_np, 0, 1) * 255).astype(np.uint8)
                current_image = PIL.Image.fromarray(result_uint8)
                current_w, current_h = current_image.width, current_image.height

        # Format output (unknown output types fall back to PIL)
        result_uint8 = (np.clip(decoded_np, 0, 1) * 255).astype(np.uint8)
        if output_type == "np":
            block_state.images = [decoded_np]
        elif output_type == "pt":
            block_state.images = [torch.from_numpy(decoded_np).permute(2, 0, 1).unsqueeze(0)]
        else:
            block_state.images = [PIL.Image.fromarray(result_uint8)]

        # Metadata
        total_time = time.perf_counter() - start_time
        block_state.metadata = {
            "input_size": orig_input_size,
            "output_size": (block_state.width, block_state.height),
            "upscale_factor": upscale_factor,
            "num_passes": num_passes,
            "strength_per_pass": strength_per_pass,
            "total_time": total_time,
        }

        if return_metadata:
            print(f"  Input size: {orig_input_size}")
            print(f"  Output size: ({block_state.width}, {block_state.height})")
            print(f"  Upscale factor: {upscale_factor}")
            print(f"  Num passes: {num_passes}")
            print(f"  Strength per pass: {strength_per_pass}")
            print(f"  Total time: {total_time:.1f}s")

        self.set_block_state(state, block_state)
        return components, state


# ============================================================
# Assembled blocks
# ============================================================


class MultiDiffusionUpscaleBlocks(SequentialPipelineBlocks):
    """Text encoding -> Lanczos upscale -> MultiDiffusion tiled denoise, run in sequence."""

    model_name = "z-image"
    block_classes = [
        ZImageTextEncoderStep,
        ZImageUpscaleStep,
        ZImageMultiDiffusionStep,
    ]
    block_names = ["text_encoder", "upscale", "multidiffusion"]

    @property
    def description(self):
        return (
            "MultiDiffusion upscale pipeline for Z-Image.\n"
            "1. Text encoding (Qwen3)\n"
            "2. Lanczos upscale\n"
            "3. MultiDiffusion tiled denoise + VAE decode"
        )

    @property
    def outputs(self):
        return [
            OutputParam("images", description="The upscaled images."),
            OutputParam("metadata", type_hint=dict, description="Generation metadata."),
        ]