Spaces:
Running
on
Zero
Running
on
Zero
File size: 11,654 Bytes
660c6ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 |
# Copyright 2025, DEVAIEXP Team, Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import torch
import numpy as np
from enum import Enum
from typing import List, Optional, Union
from diffusers.utils import logging
from diffusers.pipelines.z_image.pipeline_z_image import ZImagePipeline, calculate_shift
from diffusers.pipelines.z_image.pipeline_output import ZImagePipelineOutput
# Module-level logger obtained from diffusers' logging helper, keyed by this module's name.
logger = logging.get_logger(__name__)
# Tiling Engine Helper Functions
# (pixel-space tile sizing / placement and pixel-to-latent index conversion)
def _adaptive_tile_size(image_size, base_tile_size=512, max_tile_size=1280):
width, height = image_size; aspect_ratio = width / height
if aspect_ratio > 1:
tile_width = min(width, max_tile_size); tile_height = min(int(tile_width / aspect_ratio), max_tile_size)
else:
tile_height = min(height, max_tile_size); tile_width = min(int(tile_height * aspect_ratio), max_tile_size)
return max(tile_width, base_tile_size), max(tile_height, base_tile_size)
def _calculate_tile_positions(image_dim: int, tile_dim: int, overlap: int) -> List[int]:
if image_dim <= tile_dim: return [0]
positions = []; current_pos = 0; stride = tile_dim - overlap
while True:
positions.append(current_pos)
if current_pos + tile_dim >= image_dim: break
current_pos += stride
if current_pos > image_dim - tile_dim: break
if positions[-1] + tile_dim < image_dim: positions.append(image_dim - tile_dim)
return sorted(list(set(positions)))
def _tile2pixel_indices(tile_row_pos, tile_col_pos, tile_width, tile_height, image_width, image_height):
px_row_init = tile_row_pos; px_col_init = tile_col_pos
px_row_end = min(px_row_init + tile_height, image_height)
px_col_end = min(px_col_init + tile_width, image_width)
return px_row_init, px_row_end, px_col_init, px_col_end
def _tile2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end, vae_scale_factor):
return px_row_init // vae_scale_factor, px_row_end // vae_scale_factor, px_col_init // vae_scale_factor, px_col_end // vae_scale_factor
def release_memory(device):
    """Run Python's garbage collector and release cached CUDA memory when CUDA is present.

    Args:
        device: Accepted for call-site symmetry; it is not consulted.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
class ZImageMoDTilingPipeline(ZImagePipeline):
    """Z-Image text-to-image pipeline that denoises the canvas as overlapping tiles.

    At every timestep each tile's latents are run through the transformer with that
    tile's prompt embedding. The per-tile noise predictions are accumulated into a
    full-canvas buffer with per-tile Cosine or Gaussian weights, normalized by the
    summed weights, and a single scheduler step is taken on the whole canvas.
    """

    class TileWeightingMethod(Enum):
        """Blending profiles accepted by ``tile_weighting_method``."""

        COSINE = "Cosine"
        GAUSSIAN = "Gaussian"

    def _generate_gaussian_weights(self, tile_width, tile_height, nbatches, device, dtype, sigma=0.4):
        """Return a Gaussian blending mask of shape ``(nbatches, in_channels, h, w)`` in latent space.

        ``tile_width`` / ``tile_height`` are pixels and are floor-divided by the VAE
        scale factor. The mask peaks at the tile centre; ``sigma`` is measured on a
        ``[-1, 1]`` grid spanning the tile. All values are strictly positive.
        """
        latent_width = tile_width // self.vae_scale_factor
        latent_height = tile_height // self.vae_scale_factor
        x = np.linspace(-1, 1, latent_width)
        y = np.linspace(-1, 1, latent_height)
        xx, yy = np.meshgrid(x, y)
        gaussian_weight_np = np.exp(-(xx**2 + yy**2) / (2 * sigma**2))
        # Materialize in float32 first, then cast to the working dtype.
        weights = torch.tensor(gaussian_weight_np, device=device, dtype=torch.float32).to(dtype)
        return torch.tile(weights, (nbatches, self.transformer.in_channels, 1, 1))

    def _generate_cosine_weights(self, tile_width, tile_height, nbatches, device, dtype):
        """Return a separable cosine blending mask of shape ``(nbatches, in_channels, h, w)``.

        The window peaks at the tile centre; because the argument is divided by the
        full latent extent, the edge values stay slightly positive, so every covered
        latent receives a non-zero weight.
        """
        latent_width = tile_width // self.vae_scale_factor
        latent_height = tile_height // self.vae_scale_factor
        x = np.arange(latent_width)
        y = np.arange(latent_height)
        mid_x = (latent_width - 1) / 2
        mid_y = (latent_height - 1) / 2
        x_probs = np.cos(np.pi * (x - mid_x) / latent_width)
        y_probs = np.cos(np.pi * (y - mid_y) / latent_height)
        weights = torch.tensor(np.outer(y_probs, x_probs), device=device, dtype=dtype)
        return torch.tile(weights, (nbatches, self.transformer.in_channels, 1, 1))

    def prepare_tiles_weights(self, y_steps, x_steps, tile_height, tile_width, final_height, final_width, tile_weighting_method, tile_gaussian_sigma, batch_size, device, dtype):
        """Build a ``(len(y_steps), len(x_steps))`` object array of per-tile weight tensors.

        Edge tiles are clipped to the canvas, so their masks are generated for the
        clipped size. ``tile_weighting_method`` may be ``"Cosine"``, ``"Gaussian"``
        or a ``TileWeightingMethod`` member; any other value falls back to Gaussian.
        """
        if isinstance(tile_weighting_method, self.TileWeightingMethod):
            tile_weighting_method = tile_weighting_method.value
        tile_weights = np.empty((len(y_steps), len(x_steps)), dtype=object)
        for row, y_start in enumerate(y_steps):
            for col, x_start in enumerate(x_steps):
                _, px_row_end, _, px_col_end = _tile2pixel_indices(
                    y_start, x_start, tile_width, tile_height, final_width, final_height
                )
                current_tile_h = px_row_end - y_start
                current_tile_w = px_col_end - x_start
                if tile_weighting_method == self.TileWeightingMethod.COSINE.value:
                    tile_weights[row, col] = self._generate_cosine_weights(
                        current_tile_w, current_tile_h, batch_size, device, dtype
                    )
                else:
                    tile_weights[row, col] = self._generate_gaussian_weights(
                        current_tile_w, current_tile_h, batch_size, device, dtype, sigma=tile_gaussian_sigma
                    )
        return tile_weights

    @staticmethod
    def _validate_prompt_grid(prompt):
        """Return ``(rows, cols)`` of a rectangular prompt grid; raise ValueError otherwise."""
        if not prompt or not prompt[0]:
            raise ValueError("`prompt` grid must contain at least one row and one column.")
        grid_cols = len(prompt[0])
        if any(len(row) != grid_cols for row in prompt):
            raise ValueError("All rows of the `prompt` grid must contain the same number of prompts.")
        return len(prompt), grid_cols

    def _compute_tile_layout(self, prompt, is_prompt_grid, width, height, max_tile_size, tile_overlap):
        """Return ``(final_width, final_height, tile_width, tile_height, x_steps, y_steps)`` in pixels.

        Grid mode sizes tiles so that ``rows x cols`` tiles with ``tile_overlap``
        roughly span ``width x height`` and derives the canvas from the rounded tiles.
        Single-prompt mode rounds the requested canvas down to a multiple of
        ``2 * vae_scale_factor`` and lays adaptive tiles over it.

        Raises:
            ValueError: for a malformed prompt grid or a layout whose tiles cannot be
                mapped consistently onto the latent canvas.
        """
        pixel_multiple = self.vae_scale_factor * 2
        if is_prompt_grid:
            grid_rows, grid_cols = self._validate_prompt_grid(prompt)
            tile_width = (width + (grid_cols - 1) * tile_overlap) // grid_cols
            tile_height = (height + (grid_rows - 1) * tile_overlap) // grid_rows
            tile_width -= tile_width % pixel_multiple
            tile_height -= tile_height % pixel_multiple
            if tile_width <= 0 or tile_height <= 0:
                raise ValueError(
                    f"Requested size {width}x{height} is too small for a {grid_rows}x{grid_cols} prompt grid."
                )
            if (grid_cols > 1 and tile_width <= tile_overlap) or (grid_rows > 1 and tile_height <= tile_overlap):
                raise ValueError(
                    f"tile_overlap ({tile_overlap}) must be smaller than the computed tile size "
                    f"({tile_width}x{tile_height})."
                )
            final_width = tile_width * grid_cols - (grid_cols - 1) * tile_overlap
            final_height = tile_height * grid_rows - (grid_rows - 1) * tile_overlap
            # NOTE(review): prepare_latents (inherited, not shown here) appears to floor the
            # latent canvas to this multiple; a canvas off the multiple leaves edge tiles
            # larger than the latent slice they are written to, so reject it up front.
            if final_width % pixel_multiple or final_height % pixel_multiple:
                raise ValueError(
                    f"tile_overlap ({tile_overlap}) yields a {final_width}x{final_height} canvas that is not a "
                    f"multiple of {pixel_multiple}; choose a tile_overlap that is a multiple of {pixel_multiple}."
                )
            x_steps = [i * (tile_width - tile_overlap) for i in range(grid_cols)]
            y_steps = [i * (tile_height - tile_overlap) for i in range(grid_rows)]
        else:
            # Round the canvas to the multiple the latent canvas is assumed to use (see NOTE above),
            # so pixel tile positions and latent slices stay aligned at the far edges.
            final_width = width - width % pixel_multiple
            final_height = height - height % pixel_multiple
            if final_width <= 0 or final_height <= 0:
                raise ValueError(f"width and height must be at least {pixel_multiple} (got {width}x{height}).")
            tile_width, tile_height = _adaptive_tile_size((final_width, final_height), max_tile_size=max_tile_size)
            tile_width -= tile_width % pixel_multiple
            tile_height -= tile_height % pixel_multiple
            y_steps = _calculate_tile_positions(final_height, tile_height, tile_overlap)
            x_steps = _calculate_tile_positions(final_width, tile_width, tile_overlap)
        return final_width, final_height, tile_width, tile_height, x_steps, y_steps

    def _encode_tile_prompts(self, prompt, is_prompt_grid, grid_rows, grid_cols, device, max_sequence_length):
        """Return a ``grid_rows x grid_cols`` list of ``{"prompt": embeds}`` dicts, one per tile.

        A shared (non-grid) prompt is encoded once and its embedding reused for every
        tile instead of re-running the text encoder per tile.
        """

        def encode(text):
            prompt_embeds, _ = self.encode_prompt(
                prompt=text, do_classifier_free_guidance=False, device=device, max_sequence_length=max_sequence_length
            )
            return {"prompt": prompt_embeds}

        if is_prompt_grid:
            return [[encode(prompt[r][c]) for c in range(grid_cols)] for r in range(grid_rows)]
        shared = encode(prompt)
        return [[shared for _ in range(grid_cols)] for _ in range(grid_rows)]

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[List[str]]],
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 9,
        guidance_scale: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        max_tile_size: int = 1024,
        tile_overlap: int = 256,
        tile_weighting_method: str = "Cosine",
        tile_gaussian_sigma: float = 0.4,
        guidance_scale_tiles: Optional[List[List[float]]] = None,
        max_sequence_length: int = 512,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ):
        """Generate an image by tiled denoising.

        Args:
            prompt: One prompt shared by all tiles, or a rectangular 2-D list of
                prompts (row-major, one per tile).
            height, width: Requested canvas size in pixels. In grid mode the final
                size is derived from the rounded tiles; otherwise it is rounded down
                to a multiple of ``2 * vae_scale_factor``.
            num_inference_steps: Number of scheduler steps.
            guidance_scale: Accepted for API compatibility; not used.
            generator: RNG(s) forwarded to ``prepare_latents`` for the initial noise.
            max_tile_size: Upper bound for adaptive tiles (single-prompt mode).
            tile_overlap: Overlap between neighbouring tiles in pixels.
            tile_weighting_method: ``"Cosine"``, ``"Gaussian"`` or a
                ``TileWeightingMethod`` member; other values fall back to Gaussian.
            tile_gaussian_sigma: Spread of the Gaussian weights.
            guidance_scale_tiles: Accepted for API compatibility; not used.
            max_sequence_length: Maximum prompt length passed to ``encode_prompt``.
            output_type: ``"latent"`` returns raw latents; other values are passed to
                ``image_processor.postprocess``.
            return_dict: Wrap the result in ``ZImagePipelineOutput`` when True.
            **kwargs: Ignored (for example ``negative_prompt``).

        Returns:
            ``ZImagePipelineOutput`` when ``return_dict`` is True, else the images/latents.

        Raises:
            ValueError: for a malformed prompt grid or an impossible tile layout.
        """
        device = self._execution_device
        batch_size = 1
        is_prompt_grid = isinstance(prompt, list) and all(isinstance(row, list) for row in prompt)

        # Grid and dimension calculation
        final_width, final_height, tile_width, tile_height, x_steps, y_steps = self._compute_tile_layout(
            prompt, is_prompt_grid, width, height, max_tile_size, tile_overlap
        )
        grid_rows, grid_cols = len(y_steps), len(x_steps)

        # Prompt encoding (one embedding per tile)
        text_embeddings = self._encode_tile_prompts(
            prompt, is_prompt_grid, grid_rows, grid_cols, device, max_sequence_length
        )

        # Latent and scheduler setup
        num_latent_channels = self.transformer.in_channels
        latents = self.prepare_latents(
            batch_size, num_latent_channels, final_height, final_width, torch.float32, device, generator
        )
        # The shift is driven by the token count of one tile in latent space. This assumes
        # the transformer uses 2x2 latent patches, mirroring the upstream pipeline's
        # (latent_h // 2) * (latent_w // 2) formula. The old code used pixel sizes,
        # inflating the count by vae_scale_factor**2.
        latent_tile_height = min(tile_height, final_height) // self.vae_scale_factor
        latent_tile_width = min(tile_width, final_width) // self.vae_scale_factor
        image_seq_len = (latent_tile_height // 2) * (latent_tile_width // 2)
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )
        self.scheduler.set_timesteps(num_inference_steps, device=device, mu=mu)
        timesteps = self.scheduler.timesteps

        # Blending weights and latent slice bounds do not change across timesteps: build once.
        tile_weights = self.prepare_tiles_weights(
            y_steps, x_steps, tile_height, tile_width, final_height, final_width,
            tile_weighting_method, tile_gaussian_sigma, batch_size, device, latents.dtype,
        )
        tile_slices = [
            [
                _tile2latent_indices(
                    *_tile2pixel_indices(y_start, x_start, tile_width, tile_height, final_width, final_height),
                    self.vae_scale_factor,
                )
                for x_start in x_steps
            ]
            for y_start in y_steps
        ]

        # The text encoder is no longer needed; move it off the accelerator before denoising.
        self.text_encoder.to("cpu")
        release_memory(device)

        # Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for t in timesteps:
                noise_pred_sum = torch.zeros_like(latents)
                weight_sum = torch.zeros_like(latents)
                # Same timestep conditioning for every tile of this step.
                timestep_norm = (1000 - t.expand(batch_size)) / 1000
                for r, row_slices in enumerate(tile_slices):
                    for c, (r_init, r_end, c_init, c_end) in enumerate(row_slices):
                        tile_latents = latents[:, :, r_init:r_end, c_init:c_end].to(self.transformer.dtype)
                        # The transformer takes a list of per-sample tensors with an extra frame axis.
                        latent_model_input_list = list(tile_latents.unsqueeze(2).unbind(dim=0))
                        model_out_list = self.transformer(
                            latent_model_input_list,
                            timestep_norm,
                            text_embeddings[r][c]["prompt"],
                        )[0]
                        # Negated as in the original code; batch_size is 1, so only item 0 is used.
                        noise_pred_tile = -model_out_list[0].float().squeeze(1)
                        weight = tile_weights[r, c]
                        noise_pred_sum[:, :, r_init:r_end, c_init:c_end] += noise_pred_tile * weight
                        weight_sum[:, :, r_init:r_end, c_init:c_end] += weight
                # Weights are strictly positive on every covered latent, so this is a weighted average.
                noise_pred = noise_pred_sum / weight_sum
                latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents).prev_sample
                progress_bar.update()

        # Post-processing
        if output_type == "latent":
            image = latents
        else:
            self.vae.to(device)
            latents = latents.to(self.vae.dtype)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        self.maybe_free_model_hooks()
        return ZImagePipelineOutput(images=image) if return_dict else image