File size: 11,654 Bytes
660c6ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
# Copyright 2025, DEVAIEXP Team, Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import torch
import numpy as np
from enum import Enum
from typing import List, Optional, Union

from diffusers.utils import logging
from diffusers.pipelines.z_image.pipeline_z_image import ZImagePipeline, calculate_shift
from diffusers.pipelines.z_image.pipeline_output import ZImagePipelineOutput

logger = logging.get_logger(__name__)

# Tiling Engine Helper Functions
def _adaptive_tile_size(image_size, base_tile_size=512, max_tile_size=1280):
    width, height = image_size; aspect_ratio = width / height
    if aspect_ratio > 1:
        tile_width = min(width, max_tile_size); tile_height = min(int(tile_width / aspect_ratio), max_tile_size)
    else:
        tile_height = min(height, max_tile_size); tile_width = min(int(tile_height * aspect_ratio), max_tile_size)
    return max(tile_width, base_tile_size), max(tile_height, base_tile_size)

def _calculate_tile_positions(image_dim: int, tile_dim: int, overlap: int) -> List[int]:
    if image_dim <= tile_dim: return [0]
    positions = []; current_pos = 0; stride = tile_dim - overlap
    while True:
        positions.append(current_pos)
        if current_pos + tile_dim >= image_dim: break
        current_pos += stride
        if current_pos > image_dim - tile_dim: break
    if positions[-1] + tile_dim < image_dim: positions.append(image_dim - tile_dim)
    return sorted(list(set(positions)))

def _tile2pixel_indices(tile_row_pos, tile_col_pos, tile_width, tile_height, image_width, image_height):
    px_row_init = tile_row_pos; px_col_init = tile_col_pos
    px_row_end = min(px_row_init + tile_height, image_height)
    px_col_end = min(px_col_init + tile_width, image_width)
    return px_row_init, px_row_end, px_col_init, px_col_end

def _tile2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end, vae_scale_factor):
    return px_row_init // vae_scale_factor, px_row_end // vae_scale_factor, px_col_init // vae_scale_factor, px_col_end // vae_scale_factor

def release_memory(device):
    """Run the garbage collector and, when CUDA is available, empty its cache.

    Args:
        device: Accepted for interface compatibility; the body does not read it.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

class ZImageMoDTilingPipeline(ZImagePipeline):
    """Z-Image pipeline that denoises a single latent canvas tile by tile.

    At every timestep each (overlapping) tile of the canvas is passed through
    the transformer on its own, the per-tile predictions are blended with a
    cosine or Gaussian window, and one scheduler step is applied to the whole
    canvas. Each tile may be driven by its own prompt when ``prompt`` is given
    as a grid (list of rows of strings).
    """

    class TileWeightingMethod(Enum):
        """Blending windows available for combining overlapping tiles."""

        COSINE = "Cosine"
        GAUSSIAN = "Gaussian"

    def _generate_gaussian_weights(self, tile_width, tile_height, nbatches, device, dtype, sigma=0.4):
        """Build a 2D Gaussian blending window for one tile.

        The window is evaluated on a ``[-1, 1] x [-1, 1]`` grid at latent
        resolution (pixel size divided by ``self.vae_scale_factor``).

        Args:
            tile_width: Tile width in pixels.
            tile_height: Tile height in pixels.
            nbatches: Number of repetitions along the batch axis.
            device: Device of the returned tensor.
            dtype: Dtype of the returned tensor.
            sigma: Standard deviation of the Gaussian on the normalized grid.

        Returns:
            Tensor of shape ``(nbatches, in_channels, latent_height, latent_width)``.
        """
        latent_width = tile_width // self.vae_scale_factor
        latent_height = tile_height // self.vae_scale_factor
        x = np.linspace(-1, 1, latent_width)
        y = np.linspace(-1, 1, latent_height)
        xx, yy = np.meshgrid(x, y)
        gaussian_weight_np = np.exp(-(xx**2 + yy**2) / (2 * sigma**2))
        # Materialize in float32 first, then cast to the requested dtype.
        weights_torch_f32 = torch.tensor(gaussian_weight_np, device=device, dtype=torch.float32)
        weights_torch_target_dtype = weights_torch_f32.to(dtype)
        return torch.tile(weights_torch_target_dtype, (nbatches, self.transformer.in_channels, 1, 1))

    def _generate_cosine_weights(self, tile_width, tile_height, nbatches, device, dtype):
        """Build a separable cosine blending window for one tile.

        The window is the outer product of two 1D cosine profiles centred on
        the tile, evaluated at latent resolution.

        Args:
            tile_width: Tile width in pixels.
            tile_height: Tile height in pixels.
            nbatches: Number of repetitions along the batch axis.
            device: Device of the returned tensor.
            dtype: Dtype of the returned tensor.

        Returns:
            Tensor of shape ``(nbatches, in_channels, latent_height, latent_width)``.
        """
        latent_width = tile_width // self.vae_scale_factor
        latent_height = tile_height // self.vae_scale_factor
        x = np.arange(latent_width)
        y = np.arange(latent_height)
        mid_x = (latent_width - 1) / 2
        mid_y = (latent_height - 1) / 2
        x_probs = np.cos(np.pi * (x - mid_x) / latent_width)
        y_probs = np.cos(np.pi * (y - mid_y) / latent_height)
        weights_np = np.outer(y_probs, x_probs)
        weights_torch = torch.tensor(weights_np, device=device, dtype=dtype)
        return torch.tile(weights_torch, (nbatches, self.transformer.in_channels, 1, 1))

    def prepare_tiles_weights(self, y_steps, x_steps, tile_height, tile_width, final_height, final_width, tile_weighting_method, tile_gaussian_sigma, batch_size, device, dtype):
        """Create the blending window for every tile of the grid.

        Args:
            y_steps: Pixel row offsets of the tile rows.
            x_steps: Pixel column offsets of the tile columns.
            tile_height: Nominal tile height in pixels.
            tile_width: Nominal tile width in pixels.
            final_height: Canvas height in pixels (tiles are clipped to it).
            final_width: Canvas width in pixels (tiles are clipped to it).
            tile_weighting_method: ``"Cosine"`` or a ``TileWeightingMethod``
                member. Any other value selects the Gaussian window.
            tile_gaussian_sigma: Sigma forwarded to the Gaussian window.
            batch_size: Batch size of the weight tensors.
            device: Device of the weight tensors.
            dtype: Dtype of the weight tensors.

        Returns:
            Object-dtype ``numpy`` array of shape ``(len(y_steps), len(x_steps))``
            whose entries are the per-tile weight tensors.
        """
        # Accept the enum member as well as its string value; previously an enum
        # member never compared equal to the string and silently selected Gaussian.
        if isinstance(tile_weighting_method, self.TileWeightingMethod):
            tile_weighting_method = tile_weighting_method.value
        use_cosine = tile_weighting_method == self.TileWeightingMethod.COSINE.value

        tile_weights = np.empty((len(y_steps), len(x_steps)), dtype=object)
        for row, y_start in enumerate(y_steps):
            for col, x_start in enumerate(x_steps):
                _, px_row_end, _, px_col_end = _tile2pixel_indices(
                    y_start, x_start, tile_width, tile_height, final_width, final_height
                )
                # Tiles at the border may be clipped, so size the window from the actual extent.
                current_tile_h = px_row_end - y_start
                current_tile_w = px_col_end - x_start
                if use_cosine:
                    tile_weights[row, col] = self._generate_cosine_weights(
                        current_tile_w, current_tile_h, batch_size, device, dtype
                    )
                else:
                    tile_weights[row, col] = self._generate_gaussian_weights(
                        current_tile_w, current_tile_h, batch_size, device, dtype, sigma=tile_gaussian_sigma
                    )
        return tile_weights

    def _encode_tile_prompt(self, prompt, device, max_sequence_length):
        """Encode one prompt without classifier-free guidance.

        Returns:
            ``{"prompt": prompt_embeds}`` where ``prompt_embeds`` is the first
            value returned by ``self.encode_prompt``.
        """
        prompt_embeds, _ = self.encode_prompt(
            prompt=prompt,
            do_classifier_free_guidance=False,
            device=device,
            max_sequence_length=max_sequence_length,
        )
        return {"prompt": prompt_embeds}

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[List[str]]],
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 9,
        guidance_scale: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        max_tile_size: int = 1024,
        tile_overlap: int = 256,
        tile_weighting_method: str = "Cosine",
        tile_gaussian_sigma: float = 0.4,
        guidance_scale_tiles: Optional[List[List[float]]] = None,
        max_sequence_length: int = 512,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ):
        """Generate an image by denoising overlapping tiles of one latent canvas.

        Args:
            prompt: A single prompt used for every tile, or a grid
                (list of rows, each a list of strings) with one prompt per tile.
                With a grid, the tile layout follows the grid shape.
            height: Requested image height in pixels. With a prompt grid the
                final height is derived from the tile size and may differ.
            width: Requested image width in pixels. With a prompt grid the
                final width is derived from the tile size and may differ.
            num_inference_steps: Number of scheduler steps.
            guidance_scale: Accepted for interface compatibility; not read by
                this implementation (prompts are encoded without
                classifier-free guidance).
            generator: Forwarded to ``prepare_latents``.
            max_tile_size: Upper bound on the tile size in single-prompt mode.
            tile_overlap: Overlap between neighbouring tiles in pixels.
            tile_weighting_method: ``"Cosine"`` or ``"Gaussian"`` (a
                ``TileWeightingMethod`` member is accepted as well).
            tile_gaussian_sigma: Sigma of the Gaussian blending window.
            guidance_scale_tiles: Accepted for interface compatibility; not
                read by this implementation.
            max_sequence_length: Forwarded to ``encode_prompt``.
            output_type: ``"latent"`` returns the final latents; any other
                value is decoded by the VAE and handed to the image processor.
            return_dict: When true, wrap the result in ``ZImagePipelineOutput``.
            **kwargs: Ignored.

        Returns:
            ``ZImagePipelineOutput`` when ``return_dict`` is true, otherwise the
            image(s) / latents directly.

        Raises:
            ValueError: If ``prompt`` is a grid that is empty or has a row
                shorter than the first row.
        """
        device = self._execution_device
        batch_size = 1  # The pipeline always denoises a single canvas.
        is_prompt_grid = isinstance(prompt, list) and all(isinstance(row, list) for row in prompt)
        # Tile sizes are rounded down to a multiple of twice the VAE scale factor.
        pixel_multiple = self.vae_scale_factor * 2

        # Grid and Dimension Calculation
        if is_prompt_grid:
            if not prompt or not prompt[0]:
                raise ValueError("`prompt` grid must contain at least one row and one column.")
            grid_rows, grid_cols = len(prompt), len(prompt[0])
            for row_idx, row in enumerate(prompt):
                if len(row) < grid_cols:
                    raise ValueError(
                        f"`prompt` grid row {row_idx} has {len(row)} entries but the first row has {grid_cols}."
                    )
                if len(row) > grid_cols:
                    logger.warning(
                        f"`prompt` grid row {row_idx} has {len(row)} entries; "
                        f"only the first {grid_cols} are used."
                    )

            tile_width = (width + (grid_cols - 1) * tile_overlap) // grid_cols
            tile_height = (height + (grid_rows - 1) * tile_overlap) // grid_rows
            tile_width -= tile_width % pixel_multiple
            tile_height -= tile_height % pixel_multiple
            # The canvas is rebuilt from the rounded tile size so the tiles cover it exactly.
            final_width = tile_width * grid_cols - (grid_cols - 1) * tile_overlap
            final_height = tile_height * grid_rows - (grid_rows - 1) * tile_overlap
            x_steps = [i * (tile_width - tile_overlap) for i in range(grid_cols)]
            y_steps = [i * (tile_height - tile_overlap) for i in range(grid_rows)]
        else:
            final_width, final_height = width, height
            tile_width, tile_height = _adaptive_tile_size((final_width, final_height), max_tile_size=max_tile_size)
            tile_width -= tile_width % pixel_multiple
            tile_height -= tile_height % pixel_multiple
            y_steps = _calculate_tile_positions(final_height, tile_height, tile_overlap)
            x_steps = _calculate_tile_positions(final_width, tile_width, tile_overlap)
            grid_rows, grid_cols = len(y_steps), len(x_steps)

        # Prompt Encoding
        if is_prompt_grid:
            text_embeddings = [
                [self._encode_tile_prompt(prompt[r][c], device, max_sequence_length) for c in range(grid_cols)]
                for r in range(grid_rows)
            ]
        else:
            # Every tile uses the same prompt: encode it once and share the result.
            shared_embeds = self._encode_tile_prompt(prompt, device, max_sequence_length)
            text_embeddings = [[shared_embeds] * grid_cols for _ in range(grid_rows)]

        # Latent and Scheduler Setup
        num_latent_channels = self.transformer.in_channels
        latents = self.prepare_latents(
            batch_size, num_latent_channels, final_height, final_width, torch.float32, device, generator
        )

        # NOTE(review): this uses the pixel tile size; confirm whether the latent
        # tile size was intended for the sequence length passed to calculate_shift.
        image_seq_len = (tile_height // 2) * (tile_width // 2)
        mu = calculate_shift(image_seq_len)
        self.scheduler.set_timesteps(num_inference_steps, device=device, mu=mu)
        timesteps = self.scheduler.timesteps

        # Prepare Weights and Offload
        tile_weights = self.prepare_tiles_weights(
            y_steps,
            x_steps,
            tile_height,
            tile_width,
            final_height,
            final_width,
            tile_weighting_method,
            tile_gaussian_sigma,
            batch_size,
            device,
            latents.dtype,
        )
        # The text encoder is no longer needed once the prompts are encoded.
        self.text_encoder.to("cpu")
        release_memory(device)

        # Tile bounds do not depend on the timestep, so compute them once.
        tile_regions = []
        for r, y_start in enumerate(y_steps):
            for c, x_start in enumerate(x_steps):
                px_r_init, px_r_end, px_c_init, px_c_end = _tile2pixel_indices(
                    y_start, x_start, tile_width, tile_height, final_width, final_height
                )
                r_init, r_end, c_init, c_end = _tile2latent_indices(
                    px_r_init, px_r_end, px_c_init, px_c_end, self.vae_scale_factor
                )
                tile_regions.append((r, c, slice(r_init, r_end), slice(c_init, c_end)))

        # Denoising Loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for t in timesteps:
                # Weighted sum of tile predictions and the sum of weights per latent position.
                denoised_step_canvas = torch.zeros_like(latents)
                contributors = torch.zeros_like(latents)

                for r, c, rows, cols in tile_regions:
                    tile_latents = latents[:, :, rows, cols]
                    latents_typed = tile_latents.to(self.transformer.dtype)
                    embeds = text_embeddings[r][c]
                    timestep_model_input = t.expand(latents_typed.shape[0])
                    timestep_norm = (1000 - timestep_model_input) / 1000

                    # The transformer is called with a list of per-sample tensors
                    # carrying an extra axis inserted at dim 1 of each sample.
                    latent_model_input_list = list(latents_typed.unsqueeze(2).unbind(dim=0))

                    model_out_list = self.transformer(
                        latent_model_input_list,
                        timestep_norm,
                        embeds["prompt"],
                    )[0]

                    noise_pred_tile = model_out_list[0].float()
                    # Drop the extra axis again and flip the sign of the model output.
                    noise_pred_tile = -noise_pred_tile.squeeze(1)

                    weights = tile_weights[r, c]
                    denoised_step_canvas[:, :, rows, cols] += noise_pred_tile * weights
                    contributors[:, :, rows, cols] += weights

                # Normalize by the accumulated weights to blend overlapping tiles.
                noise_pred_canvas = denoised_step_canvas / contributors
                latents = self.scheduler.step(noise_pred_canvas.to(torch.float32), t, latents).prev_sample
                progress_bar.update()

        # Post-processing
        if output_type == "latent":
            image = latents
        else:
            self.vae.to(device)
            latents = latents.to(self.vae.dtype)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        self.maybe_free_model_hooks()
        return ZImagePipelineOutput(images=image) if return_dict else image