# Hub upload residue (akshan-main — "Upload block.py with huggingface_hub", rev 115c3bc verified),
# kept as a comment so the module remains importable.
"""MultiDiffusion tiled upscaling for Z-Image using Modular Diffusers.
Tiles the latent space, denoises each tile independently per timestep,
blends with cosine-ramp overlap weights, and applies one scheduler step
on the full blended prediction. Supports optional ControlNet conditioning,
progressive upscaling, auto-strength, and metadata output.
"""
import math
import time
from dataclasses import dataclass
import numpy as np
import PIL.Image
import torch
from diffusers.configuration_utils import FrozenDict
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, ZImageTransformer2DModel
from diffusers.models.controlnets import ZImageControlNetModel
from diffusers.modular_pipelines.modular_pipeline import (
ModularPipelineBlocks,
PipelineState,
SequentialPipelineBlocks,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
InputParam,
OutputParam,
)
from diffusers.modular_pipelines.z_image.encoders import (
ZImageTextEncoderStep,
retrieve_latents,
)
from diffusers.modular_pipelines.z_image.modular_pipeline import ZImageModularPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
logger = logging.get_logger(__name__)
# ============================================================
# Tiling utilities
# ============================================================
@dataclass
class LatentTileSpec:
    """Position and size of one tile in latent space (top-left y/x, height, width)."""

    y: int
    x: int
    h: int
    w: int


def plan_latent_tiles(latent_h, latent_w, tile_size=64, overlap=8):
    """Plan overlapping tiles that fully cover a ``latent_h x latent_w`` grid.

    Tiles advance by ``tile_size - overlap``; the final tile along each axis is
    snapped back so it ends flush with the edge, guaranteeing full coverage
    without ever exceeding ``tile_size`` per side.

    Args:
        latent_h: Latent-space height to cover.
        latent_w: Latent-space width to cover.
        tile_size: Maximum tile extent per side; must be positive.
        overlap: Overlap between adjacent tiles; must satisfy
            ``0 <= overlap < tile_size``.

    Returns:
        list[LatentTileSpec]: Tiles in row-major order.

    Raises:
        ValueError: If ``tile_size`` or ``overlap`` is out of range.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")
    if overlap < 0:
        # Previously unchecked: a negative overlap makes the stride exceed
        # tile_size and silently leaves uncovered gaps between tiles.
        raise ValueError(f"overlap must be non-negative, got {overlap}")
    if overlap >= tile_size:
        raise ValueError(f"overlap ({overlap}) must be less than tile_size ({tile_size})")
    stride = tile_size - overlap
    tiles = []
    y = 0
    while y < latent_h:
        h = min(tile_size, latent_h - y)
        if h < tile_size and y > 0:
            # Snap the last row back so it ends flush with the bottom edge.
            y = max(0, latent_h - tile_size)
            h = latent_h - y
        x = 0
        while x < latent_w:
            w = min(tile_size, latent_w - x)
            if w < tile_size and x > 0:
                # Snap the last column back so it ends flush with the right edge.
                x = max(0, latent_w - tile_size)
                w = latent_w - x
            tiles.append(LatentTileSpec(y=y, x=x, h=h, w=w))
            if x + w >= latent_w:
                break
            x += stride
        if y + h >= latent_h:
            break
        y += stride
    return tiles
def _make_cosine_tile_weight(
h, w, overlap, device, dtype, is_top=False, is_bottom=False, is_left=False, is_right=False
):
def _ramp(length, overlap_size, keep_start, keep_end):
ramp = torch.ones(length, device=device, dtype=dtype)
if overlap_size > 0 and length > 2 * overlap_size:
fade = 0.5 * (1.0 - torch.cos(torch.linspace(0, math.pi, overlap_size, device=device, dtype=dtype)))
if not keep_start:
ramp[:overlap_size] = fade
if not keep_end:
ramp[-overlap_size:] = fade.flip(0)
return ramp
w_h = _ramp(h, overlap, keep_start=is_top, keep_end=is_bottom)
w_w = _ramp(w, overlap, keep_start=is_left, keep_end=is_right)
return (w_h[:, None] * w_w[None, :]).unsqueeze(0).unsqueeze(0)
def _compute_auto_strength(upscale_factor, pass_index, num_passes):
if num_passes > 1:
return 0.4 if pass_index == 0 else 0.3
if upscale_factor <= 2.0:
return 0.4
elif upscale_factor <= 4.0:
return 0.25
else:
return 0.2
# ============================================================
# Upscale step
# ============================================================
class ZImageUpscaleStep(ModularPipelineBlocks):
    """Pipeline block that resamples the input image to the target scale with Lanczos."""

    model_name = "z-image"

    @property
    def description(self) -> str:
        return "Upscale input image with Lanczos interpolation"

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("image", required=True, type_hint=PIL.Image.Image),
            InputParam("scale_factor", default=2.0, type_hint=float),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("upscaled_image", type_hint=PIL.Image.Image),
            OutputParam("height", type_hint=int),
            OutputParam("width", type_hint=int),
        ]

    @torch.no_grad()
    def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        source = block_state.image
        factor = block_state.scale_factor
        # Snap target dims down to a multiple of the VAE spatial factor so the
        # latent grid divides evenly.
        granularity = components.vae_scale_factor_spatial
        target_w = int(source.width * factor) // granularity * granularity
        target_h = int(source.height * factor) // granularity * granularity
        block_state.upscaled_image = source.resize((target_w, target_h), PIL.Image.LANCZOS)
        block_state.width = target_w
        block_state.height = target_h
        self.set_block_state(state, block_state)
        return components, state
# ============================================================
# MultiDiffusion denoise step
# ============================================================
class ZImageMultiDiffusionStep(ModularPipelineBlocks):
    """MultiDiffusion tiled denoising for Z-Image with optional ControlNet.

    Encodes the upscaled image into latents, partially noises them according to
    ``strength`` (img2img style), denoises with overlapping latent tiles whose
    predictions are blended with cosine weights, applies one scheduler step per
    timestep on the blended prediction, then VAE-decodes. Optionally runs
    multiple progressive 2x passes and/or ControlNet conditioning.
    """

    model_name = "z-image"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        # Components this block pulls from the pipeline. `image_processor` and
        # `guider` are built from config when absent; `controlnet` is only
        # exercised when a `control_image` input is supplied.
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec("transformer", ZImageTransformer2DModel),
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                # NOTE(review): 8 * 2 presumably = VAE spatial downsample times
                # transformer patch size — confirm against the base Z-Image pipeline.
                config=FrozenDict({"vae_scale_factor": 8 * 2}),
                default_creation_method="from_config",
            ),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                # CFG is disabled by default; callers can supply an enabled guider.
                config=FrozenDict({"guidance_scale": 5.0, "enabled": False}),
                default_creation_method="from_config",
            ),
            ComponentSpec("controlnet", ZImageControlNetModel),
        ]

    @property
    def description(self) -> str:
        return (
            "MultiDiffusion tiled denoising: encodes the full upscaled image, "
            "denoises with overlapping latent tiles and cosine-weighted blending, "
            "then decodes the result. Supports optional ControlNet conditioning."
        )

    @property
    def inputs(self) -> list[InputParam]:
        # `tile_size`/`tile_overlap` are in latent units, not pixels.
        return [
            InputParam("upscaled_image", required=True, type_hint=PIL.Image.Image),
            InputParam("image", type_hint=PIL.Image.Image, description="Original input image (for progressive mode)."),
            InputParam("height", required=True, type_hint=int),
            InputParam("width", required=True, type_hint=int),
            InputParam("scale_factor", default=2.0, type_hint=float),
            InputParam("prompt_embeds", required=True),
            InputParam("negative_prompt_embeds"),
            InputParam("num_inference_steps", default=8, type_hint=int),
            InputParam("strength", default=0.4, type_hint=float),
            InputParam("tile_size", default=64, type_hint=int),
            InputParam("tile_overlap", default=8, type_hint=int),
            InputParam("generator"),
            InputParam("output_type", default="pil", type_hint=str),
            InputParam("control_image", description="Optional ControlNet conditioning image (PIL)."),
            InputParam("controlnet_conditioning_scale", default=0.75, type_hint=float),
            InputParam(
                "progressive",
                default=True,
                type_hint=bool,
                description="Split upscale_factor > 2 into multiple 2x passes.",
            ),
            InputParam(
                "auto_strength",
                default=True,
                type_hint=bool,
                description="Auto-scale strength based on upscale factor and pass index.",
            ),
            InputParam("return_metadata", default=False, type_hint=bool),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("images"),
            OutputParam("metadata", type_hint=dict),
        ]

    def _vae_encode(self, components, image_pil, height, width, generator, device, vae_dtype):
        """VAE-encode a PIL image to latent space (shifted/scaled into the diffusion distribution)."""
        image_tensor = components.image_processor.preprocess(image_pil, height=height, width=width)
        image_tensor = image_tensor.to(device=device, dtype=vae_dtype)
        image_latents = retrieve_latents(components.vae.encode(image_tensor), generator=generator)
        # Normalize with the VAE's shift/scale so latents match what the
        # transformer was trained on (inverted in `_vae_decode`).
        image_latents = (image_latents - components.vae.config.shift_factor) * components.vae.config.scaling_factor
        return image_latents

    def _vae_decode(self, components, latents, vae_dtype, output_type):
        """VAE-decode latents to images (inverse of `_vae_encode`'s normalization)."""
        decode_latents = latents.to(vae_dtype)
        decode_latents = decode_latents / components.vae.config.scaling_factor + components.vae.config.shift_factor
        decoded = components.vae.decode(decode_latents, return_dict=False)[0]
        return components.image_processor.postprocess(decoded, output_type=output_type)

    def _prepare_control_latents(self, components, control_image, height, width, generator, device, vae_dtype):
        """VAE-encode a control image for ControlNet conditioning.

        Accepts a PIL image (resized to the target resolution if needed) or a
        pre-processed tensor. Returns latents with an extra frame axis,
        zero-padded on the channel dim if the ControlNet expects more channels
        than the transformer's latent space provides.
        """
        if isinstance(control_image, PIL.Image.Image):
            if control_image.size != (width, height):
                control_image = control_image.resize((width, height), PIL.Image.LANCZOS)
            ctrl_tensor = components.image_processor.preprocess(control_image, height=height, width=width)
        else:
            ctrl_tensor = control_image
        ctrl_tensor = ctrl_tensor.to(device=device, dtype=vae_dtype)
        # sample_mode="argmax" -> deterministic (mode of the latent distribution).
        ctrl_latents = retrieve_latents(components.vae.encode(ctrl_tensor), generator=generator, sample_mode="argmax")
        ctrl_latents = (ctrl_latents - components.vae.config.shift_factor) * components.vae.config.scaling_factor
        ctrl_latents = ctrl_latents.unsqueeze(2)  # [B, C, 1, H, W]
        # Pad channels if controlnet expects more
        num_channels_latents = components.transformer.in_channels
        if hasattr(components.controlnet, "config") and hasattr(components.controlnet.config, "control_in_dim"):
            if num_channels_latents != components.controlnet.config.control_in_dim:
                pad_channels = components.controlnet.config.control_in_dim - num_channels_latents
                ctrl_latents = torch.cat(
                    [
                        ctrl_latents,
                        torch.zeros(
                            ctrl_latents.shape[0],
                            pad_channels,
                            *ctrl_latents.shape[2:],
                        ).to(device=ctrl_latents.device, dtype=ctrl_latents.dtype),
                    ],
                    dim=1,
                )
        return ctrl_latents

    def _run_tile_transformer(
        self,
        components,
        tile_latents,
        t,
        i,
        num_inference_steps,
        prompt_embeds,
        negative_prompt_embeds,
        dtype,
        controlnet_cond_tile=None,
        controlnet_conditioning_scale=0.75,
    ):
        """Run transformer (+ optional ControlNet) on a single tile.

        Drives the guider through its conditional/unconditional batches and
        returns the guided noise prediction for the tile, shaped like
        ``tile_latents``.
        """
        # Add a frame axis; the transformer consumes a list of per-sample tensors.
        latent_input = tile_latents.unsqueeze(2).to(dtype)
        latent_model_input = list(latent_input.unbind(dim=0))
        timestep = t.expand(tile_latents.shape[0]).to(dtype)
        # NOTE(review): scheduler timesteps count down from ~1000; this maps them
        # to the transformer's increasing 0..1 convention — confirm against the
        # base Z-Image denoise step.
        timestep = (1000 - timestep) / 1000
        guider_inputs = {"cap_feats": (prompt_embeds, negative_prompt_embeds)}
        components.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
        guider_state = components.guider.prepare_inputs(guider_inputs)
        for guider_state_batch in guider_state:
            components.guider.prepare_models(components.transformer)
            # Cast the guider-selected conditioning tensors to the model dtype.
            cond_kwargs = {}
            for k, v in guider_state_batch.as_dict().items():
                if k in guider_inputs:
                    if isinstance(v, torch.Tensor):
                        cond_kwargs[k] = v.to(dtype)
                    elif isinstance(v, list):
                        cond_kwargs[k] = [x.to(dtype) if isinstance(x, torch.Tensor) else x for x in v]
                    else:
                        cond_kwargs[k] = v
            controlnet_block_samples = None
            if controlnet_cond_tile is not None and getattr(components, "controlnet", None) is not None:
                cap_feats_for_cn = cond_kwargs.get("cap_feats", prompt_embeds)
                controlnet_block_samples = components.controlnet(
                    latent_model_input,
                    timestep,
                    cap_feats_for_cn,
                    controlnet_cond_tile,
                    conditioning_scale=controlnet_conditioning_scale,
                )
            transformer_kwargs = {"x": latent_model_input, "t": timestep, "return_dict": False, **cond_kwargs}
            if controlnet_block_samples is not None:
                transformer_kwargs["controlnet_block_samples"] = controlnet_block_samples
            model_out_list = components.transformer(**transformer_kwargs)[0]
            # Restack the per-sample outputs and drop the frame axis again.
            noise_pred = torch.stack(model_out_list, dim=0).squeeze(2)
            # NOTE(review): sign flip — the model's velocity convention appears to
            # be opposite the scheduler's; confirm against the base pipeline.
            guider_state_batch.noise_pred = -noise_pred
            components.guider.cleanup_models(components.transformer)
        # Combine conditional/unconditional predictions (e.g. CFG).
        return components.guider(guider_state)[0]

    def _run_single_pass(
        self,
        components,
        block_state,
        upscaled_image,
        h,
        w,
        control_image,
        use_controlnet,
        tile_size,
        tile_overlap,
        pass_strength,
    ):
        """Run one MultiDiffusion encode-denoise-decode pass, return decoded numpy."""
        device = components._execution_device
        vae_dtype = components.vae.dtype
        dtype = components.transformer.dtype
        generator = block_state.generator
        # Tile the VAE itself too, so full-resolution encode/decode fits in memory.
        if hasattr(components.vae, "enable_tiling"):
            components.vae.enable_tiling()
        image_latents = self._vae_encode(components, upscaled_image, h, w, generator, device, vae_dtype)
        # ControlNet latents
        full_control_latents = None
        if use_controlnet and control_image is not None:
            full_control_latents = self._prepare_control_latents(
                components, control_image, h, w, generator, device, vae_dtype
            )
        # Timesteps with pass strength: img2img-style truncation — only the
        # final `strength` fraction of the schedule is run.
        num_inference_steps = block_state.num_inference_steps
        components.scheduler.set_timesteps(num_inference_steps, device=device)
        all_timesteps = components.scheduler.timesteps
        init_timestep = min(int(num_inference_steps * pass_strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = all_timesteps[t_start:]
        num_inf_steps = len(timesteps)
        if num_inf_steps == 0:
            # strength == 0: nothing to denoise, pass the encoded image through.
            latents = image_latents
        else:
            # Partially noise the clean latents to the first retained timestep.
            noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=image_latents.dtype)
            latent_timestep = timesteps[:1].repeat(image_latents.shape[0])
            latents = components.scheduler.scale_noise(image_latents, latent_timestep, noise)
        latent_h, latent_w = latents.shape[2], latents.shape[3]
        tile_specs = plan_latent_tiles(latent_h, latent_w, tile_size, tile_overlap)
        logger.info(f"MultiDiffusion: {len(tile_specs)} tiles, latent {latent_w}x{latent_h}")
        prompt_embeds = block_state.prompt_embeds
        negative_prompt_embeds = getattr(block_state, "negative_prompt_embeds", None)
        cn_scale = getattr(block_state, "controlnet_conditioning_scale", 0.75)
        for i, t in enumerate(timesteps):
            # Accumulate cosine-weighted tile predictions in float32 for stability.
            noise_pred_accum = torch.zeros_like(latents, dtype=torch.float32)
            weight_accum = torch.zeros(1, 1, latent_h, latent_w, device=device, dtype=torch.float32)
            for tile in tile_specs:
                tile_latents = latents[:, :, tile.y : tile.y + tile.h, tile.x : tile.x + tile.w].clone()
                cn_tile = None
                if use_controlnet and full_control_latents is not None:
                    # Control latents carry a frame axis: [B, C, 1, H, W].
                    cn_tile = full_control_latents[:, :, :, tile.y : tile.y + tile.h, tile.x : tile.x + tile.w]
                tile_noise_pred = self._run_tile_transformer(
                    components,
                    tile_latents,
                    t,
                    i,
                    num_inf_steps,
                    prompt_embeds,
                    negative_prompt_embeds,
                    dtype,
                    controlnet_cond_tile=cn_tile,
                    controlnet_conditioning_scale=cn_scale,
                )
                # Tiles touching the image border keep full weight on that side.
                tile_weight = _make_cosine_tile_weight(
                    tile.h,
                    tile.w,
                    tile_overlap,
                    device,
                    torch.float32,
                    is_top=(tile.y == 0),
                    is_bottom=(tile.y + tile.h >= latent_h),
                    is_left=(tile.x == 0),
                    is_right=(tile.x + tile.w >= latent_w),
                )
                noise_pred_accum[:, :, tile.y : tile.y + tile.h, tile.x : tile.x + tile.w] += (
                    tile_noise_pred.to(torch.float32) * tile_weight
                )
                weight_accum[:, :, tile.y : tile.y + tile.h, tile.x : tile.x + tile.w] += tile_weight
            # Weighted average across tiles; clamp guards rare zero-weight pixels.
            blended = noise_pred_accum / weight_accum.clamp(min=1e-6)
            blended = torch.nan_to_num(blended, nan=0.0, posinf=0.0, neginf=0.0).to(latents.dtype)
            # Single scheduler step on the full blended prediction (MultiDiffusion).
            latents = components.scheduler.step(blended.float(), t, latents.float(), return_dict=False)[0]
            latents = latents.to(dtype=image_latents.dtype)
        decoded = self._vae_decode(components, latents, vae_dtype, "np")
        return decoded[0]

    @torch.no_grad()
    def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        t_start = time.time()
        output_type = block_state.output_type
        tile_size = block_state.tile_size
        tile_overlap = block_state.tile_overlap
        upscale_factor = getattr(block_state, "scale_factor", 2.0)
        progressive = getattr(block_state, "progressive", True)
        auto_strength = getattr(block_state, "auto_strength", True)
        return_metadata = getattr(block_state, "return_metadata", False)
        user_strength = block_state.strength
        # Progressive passes: one ~2x pass per doubling (ceil(log2(factor))).
        if progressive and upscale_factor > 2.0:
            num_passes = max(1, int(math.ceil(math.log2(upscale_factor))))
        else:
            num_passes = 1
        # Strength per pass
        strength_per_pass = []
        for p in range(num_passes):
            if auto_strength:
                strength_per_pass.append(_compute_auto_strength(upscale_factor, p, num_passes))
            else:
                strength_per_pass.append(user_strength)
        # ControlNet setup
        control_image_raw = getattr(block_state, "control_image", None)
        use_controlnet = False
        if control_image_raw is not None:
            if not hasattr(components, "controlnet") or components.controlnet is None:
                raise ValueError("`control_image` provided but `controlnet` component is missing.")
            use_controlnet = True
            logger.info("MultiDiffusion: ControlNet enabled.")
        orig_input_size = (block_state.upscaled_image.width, block_state.upscaled_image.height)
        original_image = getattr(block_state, "image", None)
        if num_passes == 1:
            ctrl_pil = control_image_raw if use_controlnet else None
            decoded_np = self._run_single_pass(
                components,
                block_state,
                upscaled_image=block_state.upscaled_image,
                h=block_state.height,
                w=block_state.width,
                control_image=ctrl_pil,
                use_controlnet=use_controlnet,
                tile_size=tile_size,
                tile_overlap=tile_overlap,
                pass_strength=strength_per_pass[0],
            )
        else:
            # Progressive mode starts from the original image; if the caller did
            # not pass it, reconstruct it by downscaling the upscaled input.
            if original_image is None:
                original_image = block_state.upscaled_image.resize(
                    (int(block_state.width / upscale_factor), int(block_state.height / upscale_factor)),
                    PIL.Image.LANCZOS,
                )
            current_image = original_image
            current_w, current_h = current_image.width, current_image.height
            for p in range(num_passes):
                # Intermediate passes double the size; the last pass lands
                # exactly on the requested output resolution.
                if p == num_passes - 1:
                    target_w = block_state.width
                    target_h = block_state.height
                else:
                    target_w = int(current_w * 2.0)
                    target_h = int(current_h * 2.0)
                sf = components.vae_scale_factor_spatial
                target_w = (target_w // sf) * sf
                target_h = (target_h // sf) * sf
                pass_upscaled = current_image.resize((target_w, target_h), PIL.Image.LANCZOS)
                # Each pass conditions ControlNet on its own Lanczos-upscaled input.
                ctrl_pil = pass_upscaled.copy() if use_controlnet else None
                logger.info(
                    f"Progressive pass {p + 1}/{num_passes}: "
                    f"{current_w}x{current_h} -> {target_w}x{target_h} "
                    f"(strength={strength_per_pass[p]:.2f})"
                )
                decoded_np = self._run_single_pass(
                    components,
                    block_state,
                    upscaled_image=pass_upscaled,
                    h=target_h,
                    w=target_w,
                    control_image=ctrl_pil,
                    use_controlnet=use_controlnet,
                    tile_size=tile_size,
                    tile_overlap=tile_overlap,
                    pass_strength=strength_per_pass[p],
                )
                # Round-trip through uint8 so the next pass consumes a PIL image.
                result_uint8 = (np.clip(decoded_np, 0, 1) * 255).astype(np.uint8)
                current_image = PIL.Image.fromarray(result_uint8)
                current_w, current_h = current_image.width, current_image.height
        # Format output
        result_uint8 = (np.clip(decoded_np, 0, 1) * 255).astype(np.uint8)
        if output_type == "pil":
            block_state.images = [PIL.Image.fromarray(result_uint8)]
        elif output_type == "np":
            block_state.images = [decoded_np]
        elif output_type == "pt":
            block_state.images = [torch.from_numpy(decoded_np).permute(2, 0, 1).unsqueeze(0)]
        else:
            # Unknown output_type: fall back to PIL.
            block_state.images = [PIL.Image.fromarray(result_uint8)]
        # Metadata
        total_time = time.time() - t_start
        block_state.metadata = {
            "input_size": orig_input_size,
            "output_size": (block_state.width, block_state.height),
            "upscale_factor": upscale_factor,
            "num_passes": num_passes,
            "strength_per_pass": strength_per_pass,
            "total_time": total_time,
        }
        if return_metadata:
            print(f" Input size: {orig_input_size}")
            print(f" Output size: ({block_state.width}, {block_state.height})")
            print(f" Upscale factor: {upscale_factor}")
            print(f" Num passes: {num_passes}")
            print(f" Strength per pass: {strength_per_pass}")
            print(f" Total time: {total_time:.1f}s")
        self.set_block_state(state, block_state)
        return components, state
# ============================================================
# Assembled blocks
# ============================================================
class MultiDiffusionUpscaleBlocks(SequentialPipelineBlocks):
    """Assembled Z-Image upscaling pipeline: text encode -> Lanczos upscale -> MultiDiffusion denoise."""

    model_name = "z-image"

    # Sub-blocks run sequentially in this order; `block_names` are the keys
    # used to address each sub-block within the assembled pipeline.
    block_classes = [
        ZImageTextEncoderStep,
        ZImageUpscaleStep,
        ZImageMultiDiffusionStep,
    ]
    block_names = ["text_encoder", "upscale", "multidiffusion"]

    @property
    def description(self):
        return (
            "MultiDiffusion upscale pipeline for Z-Image.\n"
            "1. Text encoding (Qwen3)\n"
            "2. Lanczos upscale\n"
            "3. MultiDiffusion tiled denoise + VAE decode"
        )

    @property
    def outputs(self):
        # Final pipeline outputs, produced by ZImageMultiDiffusionStep.
        return [
            OutputParam("images", description="The upscaled images."),
            OutputParam("metadata", type_hint=dict, description="Generation metadata."),
        ]