# z-image-panorama / pipeline_z_image_mod.py
# Uploaded by elismasilva, commit 660c6ca ("first commit").
# Copyright 2025, DEVAIEXP Team, Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import torch
import numpy as np
from enum import Enum
from typing import List, Optional, Union
from diffusers.utils import logging
from diffusers.pipelines.z_image.pipeline_z_image import ZImagePipeline, calculate_shift
from diffusers.pipelines.z_image.pipeline_output import ZImagePipelineOutput
# Module-level logger obtained from diffusers' logging utilities.
logger = logging.get_logger(__name__)
# Tiling engine helpers: tile sizing, tile placement along an axis, and
# pixel-space -> latent-space index conversion.
def _adaptive_tile_size(image_size, base_tile_size=512, max_tile_size=1280):
width, height = image_size; aspect_ratio = width / height
if aspect_ratio > 1:
tile_width = min(width, max_tile_size); tile_height = min(int(tile_width / aspect_ratio), max_tile_size)
else:
tile_height = min(height, max_tile_size); tile_width = min(int(tile_height * aspect_ratio), max_tile_size)
return max(tile_width, base_tile_size), max(tile_height, base_tile_size)
def _calculate_tile_positions(image_dim: int, tile_dim: int, overlap: int) -> List[int]:
if image_dim <= tile_dim: return [0]
positions = []; current_pos = 0; stride = tile_dim - overlap
while True:
positions.append(current_pos)
if current_pos + tile_dim >= image_dim: break
current_pos += stride
if current_pos > image_dim - tile_dim: break
if positions[-1] + tile_dim < image_dim: positions.append(image_dim - tile_dim)
return sorted(list(set(positions)))
def _tile2pixel_indices(tile_row_pos, tile_col_pos, tile_width, tile_height, image_width, image_height):
px_row_init = tile_row_pos; px_col_init = tile_col_pos
px_row_end = min(px_row_init + tile_height, image_height)
px_col_end = min(px_col_init + tile_width, image_width)
return px_row_init, px_row_end, px_col_init, px_col_end
def _tile2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end, vae_scale_factor):
return px_row_init // vae_scale_factor, px_row_end // vae_scale_factor, px_col_init // vae_scale_factor, px_col_end // vae_scale_factor
def release_memory(device):
    """Run Python's garbage collector and, if CUDA is available, empty its cache.

    ``device`` is accepted for call-site compatibility but is not consulted. The
    CUDA cache is emptied whenever CUDA is available, regardless of its value.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
class ZImageMoDTilingPipeline(ZImagePipeline):
    """Z-Image pipeline that denoises a large canvas as a grid of overlapping tiles.

    At every denoising step the transformer runs once per tile. The per-tile
    predictions are merged into one canvas-sized prediction with a per-pixel
    weighted average; the window is chosen by ``TileWeightingMethod``. A single
    scheduler step is then taken on the whole latent canvas.

    ``prompt`` is either one prompt shared by all tiles, or a rectangular 2D list
    with one prompt per tile.
    """

    class TileWeightingMethod(Enum):
        """Window shapes available for blending overlapping tile predictions."""

        COSINE = "Cosine"
        GAUSSIAN = "Gaussian"

    def _generate_gaussian_weights(self, tile_width, tile_height, nbatches, device, dtype, sigma=0.4):
        """Return a 2D Gaussian blending window for one tile, on the latent grid.

        Coordinates span [-1, 1] on both axes, so the window is largest near the tile
        centre. The latent size is the pixel size floor-divided by
        ``vae_scale_factor``. The window is repeated to shape
        ``(nbatches, transformer.in_channels, latent_h, latent_w)``.
        """
        latent_width = tile_width // self.vae_scale_factor
        latent_height = tile_height // self.vae_scale_factor
        x = np.linspace(-1, 1, latent_width)
        y = np.linspace(-1, 1, latent_height)
        xx, yy = np.meshgrid(x, y)
        gaussian_weight_np = np.exp(-(xx**2 + yy**2) / (2 * sigma**2))
        # Materialize as float32 first, then cast to the requested dtype.
        weights = torch.tensor(gaussian_weight_np, device=device, dtype=torch.float32).to(dtype)
        return torch.tile(weights, (nbatches, self.transformer.in_channels, 1, 1))

    def _generate_cosine_weights(self, tile_width, tile_height, nbatches, device, dtype):
        """Return a separable cosine blending window for one tile, on the latent grid.

        Each axis uses ``cos(pi * (i - mid) / n)``. The argument stays inside
        (-pi/2, pi/2), so the window is largest at the centre and still positive at
        the edges; every covered latent therefore gets a non-zero weight.
        """
        latent_width = tile_width // self.vae_scale_factor
        latent_height = tile_height // self.vae_scale_factor
        x, y = np.arange(latent_width), np.arange(latent_height)
        mid_x, mid_y = (latent_width - 1) / 2, (latent_height - 1) / 2
        x_probs = np.cos(np.pi * (x - mid_x) / latent_width)
        y_probs = np.cos(np.pi * (y - mid_y) / latent_height)
        weights_np = np.outer(y_probs, x_probs)
        weights = torch.tensor(weights_np, device=device, dtype=dtype)
        return torch.tile(weights, (nbatches, self.transformer.in_channels, 1, 1))

    def prepare_tiles_weights(
        self,
        y_steps,
        x_steps,
        tile_height,
        tile_width,
        final_height,
        final_width,
        tile_weighting_method,
        tile_gaussian_sigma,
        batch_size,
        device,
        dtype,
    ):
        """Build the blending window for every tile of the grid.

        Args:
            y_steps / x_steps: Pixel start offsets of tile rows / columns.
            tile_height / tile_width: Nominal tile size in pixels. Edge tiles are
                clipped to ``final_height`` / ``final_width``.
            tile_weighting_method: A ``TileWeightingMethod`` member or its name
                (case-insensitive). Unknown values fall back to Gaussian with a
                warning.
            tile_gaussian_sigma: Sigma used by the Gaussian window.

        Returns:
            A ``(len(y_steps), len(x_steps))`` object array of weight tensors.
            Tiles with the same clipped size share one tensor; the tensors are
            only read by the caller.
        """
        if isinstance(tile_weighting_method, Enum):
            method_name = str(tile_weighting_method.value)
        else:
            method_name = str(tile_weighting_method)
        known_methods = {member.value.lower() for member in self.TileWeightingMethod}
        if method_name.lower() not in known_methods:
            logger.warning(
                "Unknown tile_weighting_method %r; falling back to %r.",
                tile_weighting_method,
                self.TileWeightingMethod.GAUSSIAN.value,
            )
        use_cosine = method_name.lower() == self.TileWeightingMethod.COSINE.value.lower()

        tile_weights = np.empty((len(y_steps), len(x_steps)), dtype=object)
        weights_by_size = {}
        for row, y_start in enumerate(y_steps):
            for col, x_start in enumerate(x_steps):
                _, px_row_end, _, px_col_end = _tile2pixel_indices(
                    y_start, x_start, tile_width, tile_height, final_width, final_height
                )
                size_key = (px_row_end - y_start, px_col_end - x_start)
                if size_key not in weights_by_size:
                    current_tile_h, current_tile_w = size_key
                    if use_cosine:
                        weights_by_size[size_key] = self._generate_cosine_weights(
                            current_tile_w, current_tile_h, batch_size, device, dtype
                        )
                    else:
                        weights_by_size[size_key] = self._generate_gaussian_weights(
                            current_tile_w, current_tile_h, batch_size, device, dtype, sigma=tile_gaussian_sigma
                        )
                tile_weights[row, col] = weights_by_size[size_key]
        return tile_weights

    def _compute_tile_layout(self, prompt, is_prompt_grid, height, width, max_tile_size, tile_overlap):
        """Compute the canvas size and tile placement for a tiled generation.

        All sizes and offsets are kept at multiples of ``2 * vae_scale_factor`` (16 in
        the original comment). This makes latent tile slices line up exactly with
        the latent canvas and the weight windows.

        Returns:
            ``(final_height, final_width, tile_height, tile_width, y_steps, x_steps)``,
            where ``y_steps`` / ``x_steps`` are pixel start offsets of tile rows /
            columns.

        Raises:
            ValueError: For a negative or too-large overlap, a non-rectangular prompt
                grid, or a requested size too small to hold a tile.
        """
        pixel_multiple = self.vae_scale_factor * 2

        if tile_overlap < 0:
            raise ValueError(f"`tile_overlap` must be non-negative, got {tile_overlap}.")
        snapped_overlap = tile_overlap - tile_overlap % pixel_multiple
        if snapped_overlap != tile_overlap:
            logger.warning(
                "`tile_overlap`=%d is not a multiple of %d; using %d.", tile_overlap, pixel_multiple, snapped_overlap
            )
            tile_overlap = snapped_overlap

        if is_prompt_grid:
            grid_rows = len(prompt)
            grid_cols = len(prompt[0]) if grid_rows else 0
            if grid_rows == 0 or grid_cols == 0 or any(len(row) != grid_cols for row in prompt):
                raise ValueError("A prompt grid must be a non-empty rectangular list of lists of prompts.")
            tile_width = (width + (grid_cols - 1) * tile_overlap) // grid_cols
            tile_height = (height + (grid_rows - 1) * tile_overlap) // grid_rows
            tile_width -= tile_width % pixel_multiple
            tile_height -= tile_height % pixel_multiple
            if tile_width <= 0 or tile_height <= 0:
                raise ValueError(f"Requested size {width}x{height} is too small for a {grid_rows}x{grid_cols} grid.")
            # A non-positive stride would stack every tile at the origin.
            if (grid_cols > 1 and tile_width <= tile_overlap) or (grid_rows > 1 and tile_height <= tile_overlap):
                raise ValueError(
                    f"`tile_overlap`={tile_overlap} must be smaller than the tile size ({tile_width}x{tile_height})."
                )
            final_width = tile_width * grid_cols - (grid_cols - 1) * tile_overlap
            final_height = tile_height * grid_rows - (grid_rows - 1) * tile_overlap
            x_steps = [i * (tile_width - tile_overlap) for i in range(grid_cols)]
            y_steps = [i * (tile_height - tile_overlap) for i in range(grid_rows)]
        else:
            # The latent canvas is floored to a multiple of `pixel_multiple` by
            # `prepare_latents`. Snapping here keeps the edge tiles inside it.
            final_width = width - width % pixel_multiple
            final_height = height - height % pixel_multiple
            if final_width <= 0 or final_height <= 0:
                raise ValueError(f"`width` and `height` must be at least {pixel_multiple}, got {width}x{height}.")
            if (final_width, final_height) != (width, height):
                logger.warning(
                    "Requested size %dx%d is not a multiple of %d; generating %dx%d instead.",
                    width, height, pixel_multiple, final_width, final_height,
                )
            tile_width, tile_height = _adaptive_tile_size((final_width, final_height), max_tile_size=max_tile_size)
            tile_width -= tile_width % pixel_multiple
            tile_height -= tile_height % pixel_multiple
            for image_dim, tile_dim in ((final_height, tile_height), (final_width, tile_width)):
                # With several tiles on an axis, overlap >= tile size gives a non-positive stride.
                if image_dim > tile_dim and tile_overlap >= tile_dim:
                    raise ValueError(
                        f"`tile_overlap`={tile_overlap} must be smaller than the tile size "
                        f"({tile_width}x{tile_height})."
                    )
            y_steps = _calculate_tile_positions(final_height, tile_height, tile_overlap)
            x_steps = _calculate_tile_positions(final_width, tile_width, tile_overlap)
        return final_height, final_width, tile_height, tile_width, y_steps, x_steps

    def _encode_tile_prompts(self, prompt, is_prompt_grid, grid_rows, grid_cols, device, max_sequence_length):
        """Encode the prompt of every tile, encoding each distinct prompt only once.

        Returns:
            A ``grid_rows`` x ``grid_cols`` nested list of ``{"prompt": prompt_embeds}``
            dicts. Tiles with the same prompt share the same embeddings object.
        """
        embeds_cache = {}
        text_embeddings = []
        for r in range(grid_rows):
            row_embeddings = []
            for c in range(grid_cols):
                p = prompt[r][c] if is_prompt_grid else prompt
                # Strings are cached by value; any other prompt object by identity.
                key = p if isinstance(p, str) else id(p)
                if key not in embeds_cache:
                    prompt_embeds, _ = self.encode_prompt(
                        prompt=p, do_classifier_free_guidance=False, device=device, max_sequence_length=max_sequence_length
                    )
                    embeds_cache[key] = prompt_embeds
                row_embeddings.append({"prompt": embeds_cache[key]})
            text_embeddings.append(row_embeddings)
        return text_embeddings

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[List[str]]],
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 9,
        guidance_scale: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        max_tile_size: int = 1024,
        tile_overlap: int = 256,
        tile_weighting_method: str = "Cosine",
        tile_gaussian_sigma: float = 0.4,
        guidance_scale_tiles: Optional[List[List[float]]] = None,
        max_sequence_length: int = 512,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ):
        """Generate an image by tiled, blended denoising.

        Args:
            prompt: One prompt for every tile, or a rectangular 2D list with one
                prompt per tile. The grid shape then defines the tile grid.
            height / width: Requested output size in pixels. The actual size is
                adjusted to multiples of ``2 * vae_scale_factor`` and, in grid mode,
                to the tile layout.
            num_inference_steps: Number of scheduler steps.
            guidance_scale / guidance_scale_tiles: Accepted for API compatibility.
                Classifier-free guidance is not implemented here, so they are
                ignored (a warning is logged).
            generator: RNG(s) for the initial latents.
            max_tile_size: Maximum tile side for automatic tiling (non-grid prompt).
            tile_overlap: Overlap between neighbouring tiles in pixels.
            tile_weighting_method: ``"Cosine"`` or ``"Gaussian"``, or a
                ``TileWeightingMethod`` member.
            tile_gaussian_sigma: Sigma of the Gaussian window.
            max_sequence_length: Maximum prompt token length passed to ``encode_prompt``.
            output_type: ``"latent"`` to return latents, otherwise forwarded to the
                image processor.
            return_dict: Return a ``ZImagePipelineOutput`` instead of the raw images.
            **kwargs: Accepted for backward compatibility. ``negative_prompt`` is
                recognised only to warn that it is ignored.

        Returns:
            ``ZImagePipelineOutput`` if ``return_dict`` is true, otherwise the images
            (or latents).
        """
        device = self._execution_device
        batch_size = 1

        if guidance_scale or guidance_scale_tiles is not None or kwargs.get("negative_prompt"):
            logger.warning(
                "Classifier-free guidance is not implemented by ZImageMoDTilingPipeline; "
                "`guidance_scale`, `guidance_scale_tiles` and `negative_prompt` are ignored."
            )

        # Grid and dimension calculation
        is_prompt_grid = isinstance(prompt, list) and all(isinstance(row, list) for row in prompt)
        final_height, final_width, tile_height, tile_width, y_steps, x_steps = self._compute_tile_layout(
            prompt, is_prompt_grid, height, width, max_tile_size, tile_overlap
        )
        grid_rows, grid_cols = len(y_steps), len(x_steps)

        # Prompt encoding. The text encoder is moved to the CPU further below, so
        # bring it back first; otherwise a second call would run it on the CPU with
        # inputs on `device`.
        self.text_encoder.to(device)
        text_embeddings = self._encode_tile_prompts(
            prompt, is_prompt_grid, grid_rows, grid_cols, device, max_sequence_length
        )

        # Latent and scheduler setup
        num_latent_channels = self.transformer.in_channels
        latents = self.prepare_latents(
            batch_size, num_latent_channels, final_height, final_width, torch.float32, device, generator
        )
        # Shift from the latent token count of one (possibly clipped) tile, which is
        # what each transformer call sees: (latent_h // 2) * (latent_w // 2).
        latent_tile_height = min(tile_height, final_height) // self.vae_scale_factor
        latent_tile_width = min(tile_width, final_width) // self.vae_scale_factor
        image_seq_len = (latent_tile_height // 2) * (latent_tile_width // 2)
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )
        self.scheduler.set_timesteps(num_inference_steps, device=device, mu=mu)
        timesteps = self.scheduler.timesteps

        # Prepare weights and offload the text encoder
        tile_weights = self.prepare_tiles_weights(
            y_steps, x_steps, tile_height, tile_width, final_height, final_width,
            tile_weighting_method, tile_gaussian_sigma, batch_size, device, latents.dtype,
        )
        self.text_encoder.to("cpu")
        release_memory(device)

        # Latent-space tile slices do not change between timesteps; compute them once.
        tile_slices = []
        for r, y_start in enumerate(y_steps):
            for c, x_start in enumerate(x_steps):
                px_bounds = _tile2pixel_indices(y_start, x_start, tile_width, tile_height, final_width, final_height)
                r_init, r_end, c_init, c_end = _tile2latent_indices(*px_bounds, self.vae_scale_factor)
                tile_slices.append((r, c, slice(r_init, r_end), slice(c_init, c_end)))

        # Denoising loop
        with self.progress_bar(total=len(timesteps)) as progress_bar:
            for t in timesteps:
                weighted_pred_sum = torch.zeros_like(latents)
                weight_sum = torch.zeros_like(latents)
                for r, c, rows, cols in tile_slices:
                    tile_latents = latents[:, :, rows, cols].to(self.transformer.dtype)
                    timestep_model_input = t.expand(tile_latents.shape[0])
                    timestep_norm = (1000 - timestep_model_input) / 1000
                    latent_model_input_list = list(tile_latents.unsqueeze(2).unbind(dim=0))
                    model_out_list = self.transformer(
                        latent_model_input_list,
                        timestep_norm,
                        text_embeddings[r][c]["prompt"],
                    )[0]
                    # Drop the singleton frame axis and flip the sign before the scheduler step.
                    noise_pred_tile = -model_out_list[0].float().squeeze(1)
                    weights = tile_weights[r, c]
                    weighted_pred_sum[:, :, rows, cols] += noise_pred_tile * weights
                    weight_sum[:, :, rows, cols] += weights
                # Both windows are strictly positive, and the tiles cover the whole
                # canvas, so `weight_sum` has no zeros.
                noise_pred_canvas = weighted_pred_sum / weight_sum
                latents = self.scheduler.step(noise_pred_canvas.to(torch.float32), t, latents).prev_sample
                progress_bar.update()

        # Post-processing
        if output_type == "latent":
            image = latents
        else:
            self.vae.to(device)
            latents = latents.to(self.vae.dtype)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)
        self.maybe_free_model_hooks()
        return ZImagePipelineOutput(images=image) if return_dict else image