Upload block.py with huggingface_hub
Browse files
block.py
CHANGED
|
@@ -8,589 +8,16 @@ MultiDiffusion tiled upscaling for Stable Diffusion XL using Modular Diffusers.
|
|
| 8 |
# utils_tiling
|
| 9 |
# ============================================================
|
| 10 |
|
| 11 |
-
|
| 12 |
-
#
|
| 13 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 14 |
-
# you may not use this file except in compliance with the License.
|
| 15 |
-
# You may obtain a copy of the License at
|
| 16 |
-
#
|
| 17 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 18 |
-
#
|
| 19 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 20 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 21 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 22 |
-
# See the License for the specific language governing permissions and
|
| 23 |
-
# limitations under the License.
|
| 24 |
-
|
| 25 |
-
"""Pure utility functions for tiled upscale workflows.
|
| 26 |
-
|
| 27 |
-
Supports:
|
| 28 |
-
- Linear (raster) and chess (checkerboard) tile traversal
|
| 29 |
-
- Non-overlapping core paste and gradient overlap blending
|
| 30 |
-
- Seam-fix band planning along tile boundaries
|
| 31 |
-
- Linear feathered mask blending for seam-fix bands
|
| 32 |
-
"""
|
| 33 |
-
|
| 34 |
-
from dataclasses import dataclass, field
|
| 35 |
-
|
| 36 |
-
import numpy as np
|
| 37 |
-
import PIL.Image
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
@dataclass
class TileSpec:
    """Specification for a single tile, distinguishing the core output region
    from the padded crop region used for denoising.

    Attributes:
        core_x: Left edge of the core region in the output canvas.
        core_y: Top edge of the core region in the output canvas.
        core_w: Width of the core region (what this tile is responsible for pasting).
        core_h: Height of the core region.
        crop_x: Left edge of the padded crop region in the source image.
        crop_y: Top edge of the padded crop region in the source image.
        crop_w: Width of the padded crop region (what gets denoised).
        crop_h: Height of the padded crop region.
        paste_x: X offset of the core region within the crop (left padding amount).
        paste_y: Y offset of the core region within the crop (top padding amount).
    """

    # Core region: this tile's exclusive output responsibility in the canvas.
    core_x: int
    core_y: int
    core_w: int
    core_h: int
    # Padded crop region: the area that actually gets denoised (core + context).
    crop_x: int
    crop_y: int
    crop_w: int
    crop_h: int
    # Offset of the core within the crop (the clamped padding amounts).
    paste_x: int
    paste_y: int
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
@dataclass
class SeamFixSpec:
    """Specification for a seam-fix band along a tile boundary.

    Attributes:
        band_x: Left edge of the band in the output canvas.
        band_y: Top edge of the band in the output canvas.
        band_w: Width of the band.
        band_h: Height of the band.
        crop_x: Left edge of the padded crop for denoising.
        crop_y: Top edge of the padded crop for denoising.
        crop_w: Width of the padded crop.
        crop_h: Height of the padded crop.
        paste_x: X offset of the band within the crop.
        paste_y: Y offset of the band within the crop.
        orientation: 'horizontal' or 'vertical'.
    """

    # Band region: the seam-centered strip pasted back into the canvas.
    band_x: int
    band_y: int
    band_w: int
    band_h: int
    # Padded crop region: the area denoised to produce the band.
    crop_x: int
    crop_y: int
    crop_w: int
    crop_h: int
    # Offset of the band within the crop.
    paste_x: int
    paste_y: int
    # `field(default=...)` is only needed for mutable defaults; a plain
    # default is the idiomatic form for an immutable string.
    orientation: str = "horizontal"
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
def validate_tile_params(tile_size: int, tile_padding: int) -> None:
    """Strictly check tile geometry parameters, raising on any invalid value.

    Args:
        tile_size: Base tile size in pixels; must be positive.
        tile_padding: Overlap padding on each side; must be non-negative and
            strictly less than ``tile_size // 2`` so the core stays non-empty.

    Raises:
        ValueError: If either parameter is out of range.
    """
    if tile_size <= 0:
        raise ValueError(f"`tile_size` must be positive, got {tile_size}.")
    if tile_padding < 0:
        raise ValueError(f"`tile_padding` must be non-negative, got {tile_padding}.")
    half = tile_size // 2
    if tile_padding >= half:
        raise ValueError(
            f"`tile_padding` must be less than tile_size // 2. "
            f"Got tile_padding={tile_padding}, tile_size={tile_size} "
            f"(max allowed: {half - 1})."
        )
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
def plan_tiles_linear(
    image_width: int,
    image_height: int,
    tile_size: int = 512,
    tile_padding: int = 32,
) -> list[TileSpec]:
    """Plan tiles in left-to-right, top-to-bottom (linear/raster) order.

    Each ``TileSpec`` carries separate core (output responsibility) and crop
    (denoised region with padding context) bounds. The crop extends the core
    by ``tile_padding`` on every side, clamped to the image edges.

    Args:
        image_width: Width of the image to tile.
        image_height: Height of the image to tile.
        tile_size: Base tile size; each core covers
            ``tile_size - 2 * tile_padding`` pixels per axis.
        tile_padding: Number of overlap pixels on each side.

    Returns:
        List of ``TileSpec`` in linear traversal order.
    """
    validate_tile_params(tile_size, tile_padding)

    step = tile_size - 2 * tile_padding
    plan: list[TileSpec] = []

    for top in range(0, image_height, step):
        height = min(step, image_height - top)
        for left in range(0, image_width, step):
            width = min(step, image_width - left)

            # Padded crop rectangle, clamped to the image bounds.
            cx0 = max(left - tile_padding, 0)
            cy0 = max(top - tile_padding, 0)
            cx1 = min(left + width + tile_padding, image_width)
            cy1 = min(top + height + tile_padding, image_height)

            plan.append(
                TileSpec(
                    core_x=left,
                    core_y=top,
                    core_w=width,
                    core_h=height,
                    crop_x=cx0,
                    crop_y=cy0,
                    crop_w=cx1 - cx0,
                    crop_h=cy1 - cy0,
                    # Where the core sits inside the crop.
                    paste_x=left - cx0,
                    paste_y=top - cy0,
                )
            )

    return plan
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
def crop_tile(image: PIL.Image.Image, tile: TileSpec) -> PIL.Image.Image:
    """Crop a tile's padded region out of a PIL image.

    Args:
        image: Source image.
        tile: Tile specification.

    Returns:
        Cropped PIL image covering the padded crop region.
    """
    left, top = tile.crop_x, tile.crop_y
    box = (left, top, left + tile.crop_w, top + tile.crop_h)
    return image.crop(box)
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
def extract_core_from_decoded(decoded_image: np.ndarray, tile: TileSpec) -> np.ndarray:
    """Slice the core region out of a decoded tile image.

    Args:
        decoded_image: Decoded tile as numpy array, shape (crop_h, crop_w, C).
        tile: Tile specification.

    Returns:
        Core region as numpy array, shape (core_h, core_w, C).
    """
    rows = slice(tile.paste_y, tile.paste_y + tile.core_h)
    cols = slice(tile.paste_x, tile.paste_x + tile.core_w)
    return decoded_image[rows, cols]
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
def paste_core_into_canvas(
    canvas: np.ndarray,
    core_image: np.ndarray,
    tile: TileSpec,
) -> None:
    """Write a decoded tile's core region straight into the output canvas.

    No blending is applied — core regions partition the canvas without
    overlapping each other.

    Args:
        canvas: Output canvas, shape (H, W, C), float32. Modified in-place.
        core_image: Core tile pixels, shape (core_h, core_w, C), float32.
        tile: Tile specification.
    """
    y0, x0 = tile.core_y, tile.core_x
    canvas[y0 : y0 + tile.core_h, x0 : x0 + tile.core_w] = core_image
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
# =============================================================================
|
| 238 |
-
# Chess (checkerboard) traversal
|
| 239 |
-
# =============================================================================
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
def plan_tiles_chess(
    image_width: int,
    image_height: int,
    tile_size: int = 512,
    tile_padding: int = 32,
) -> list[TileSpec]:
    """Plan tiles in a checkerboard (chess) traversal order.

    Two passes: first every "white" square (row + col even), then every
    "black" square. Adjacent tiles are therefore never processed back to
    back, which reduces visible seam patterns.

    Args:
        image_width: Width of the image to tile.
        image_height: Height of the image to tile.
        tile_size: Base tile size.
        tile_padding: Number of overlap pixels on each side.

    Returns:
        List of ``TileSpec`` in chess traversal order.
    """
    validate_tile_params(tile_size, tile_padding)

    step = tile_size - 2 * tile_padding
    first_pass: list[TileSpec] = []   # "white" squares: (row + col) even
    second_pass: list[TileSpec] = []  # "black" squares: (row + col) odd

    for row, top in enumerate(range(0, image_height, step)):
        height = min(step, image_height - top)
        for col, left in enumerate(range(0, image_width, step)):
            width = min(step, image_width - left)

            # Padded crop rectangle, clamped to the image bounds.
            cx0 = max(left - tile_padding, 0)
            cy0 = max(top - tile_padding, 0)
            cx1 = min(left + width + tile_padding, image_width)
            cy1 = min(top + height + tile_padding, image_height)

            spec = TileSpec(
                core_x=left, core_y=top, core_w=width, core_h=height,
                crop_x=cx0, crop_y=cy0, crop_w=cx1 - cx0, crop_h=cy1 - cy0,
                paste_x=left - cx0, paste_y=top - cy0,
            )
            # Raster order within each color class matches the original
            # build-then-filter approach.
            (first_pass if (row + col) % 2 == 0 else second_pass).append(spec)

    return first_pass + second_pass
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
# =============================================================================
|
| 310 |
-
# Gradient overlap blending
|
| 311 |
-
# =============================================================================
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
def make_gradient_mask(
    core_h: int,
    core_w: int,
    overlap: int,
    at_top: bool = False,
    at_bottom: bool = False,
    at_left: bool = False,
    at_right: bool = False,
) -> np.ndarray:
    """Build a boundary-aware gradient blending mask for a tile's core region.

    The mask is 1.0 in the interior and ramps linearly from 0 to 1 across the
    overlap zone of every *interior* edge. Edges flagged as touching the
    canvas boundary (``at_*``) keep weight 1.0 so no dark border appears.

    Args:
        core_h: Height of the core region.
        core_w: Width of the core region.
        overlap: Width of the gradient ramp in pixels.
        at_top: True if the tile touches the top of the canvas.
        at_bottom: True if the tile touches the bottom of the canvas.
        at_left: True if the tile touches the left of the canvas.
        at_right: True if the tile touches the right of the canvas.

    Returns:
        Mask of shape (core_h, core_w), float32, values in [0, 1].
    """
    if overlap <= 0:
        return np.ones((core_h, core_w), dtype=np.float32)

    # Per-axis weight profiles; the 2-D mask is their pointwise minimum,
    # which is equivalent to applying each ramp to a ones-array via minimum.
    wx = np.ones(core_w, dtype=np.float32)
    ramp_w = min(overlap, core_w)
    if ramp_w > 0 and not at_left:
        wx[:ramp_w] = np.minimum(wx[:ramp_w], np.linspace(0.0, 1.0, ramp_w, dtype=np.float32))
    if ramp_w > 0 and not at_right:
        wx[-ramp_w:] = np.minimum(wx[-ramp_w:], np.linspace(1.0, 0.0, ramp_w, dtype=np.float32))

    wy = np.ones(core_h, dtype=np.float32)
    ramp_h = min(overlap, core_h)
    if ramp_h > 0 and not at_top:
        wy[:ramp_h] = np.minimum(wy[:ramp_h], np.linspace(0.0, 1.0, ramp_h, dtype=np.float32))
    if ramp_h > 0 and not at_bottom:
        wy[-ramp_h:] = np.minimum(wy[-ramp_h:], np.linspace(1.0, 0.0, ramp_h, dtype=np.float32))

    return np.minimum(wy[:, np.newaxis], wx[np.newaxis, :])
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
def paste_core_into_canvas_blended(
    canvas: np.ndarray,
    weight_map: np.ndarray,
    core_image: np.ndarray,
    tile: TileSpec,
    overlap: int,
) -> None:
    """Accumulate a tile's core into the canvas with gradient blending.

    Follows the weighted-sum scheme: ``canvas`` collects mask-weighted pixel
    sums and ``weight_map`` collects the mask weights; dividing the two in
    ``finalize_blended_canvas`` yields the blended result.

    Args:
        canvas: Output canvas, shape (H, W, C), float32. Modified in-place.
        weight_map: Weight accumulator, shape (H, W), float32. Modified in-place.
        core_image: Core tile pixels, shape (core_h, core_w, C), float32.
        tile: Tile specification.
        overlap: Gradient overlap width in pixels.
    """
    height, width = canvas.shape[:2]

    # Canvas-boundary edges keep full weight so borders do not darken.
    weights = make_gradient_mask(
        tile.core_h,
        tile.core_w,
        overlap,
        at_top=tile.core_y == 0,
        at_bottom=tile.core_y + tile.core_h >= height,
        at_left=tile.core_x == 0,
        at_right=tile.core_x + tile.core_w >= width,
    )

    rows = slice(tile.core_y, tile.core_y + tile.core_h)
    cols = slice(tile.core_x, tile.core_x + tile.core_w)
    canvas[rows, cols] += core_image * weights[:, :, np.newaxis]
    weight_map[rows, cols] += weights
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
def finalize_blended_canvas(canvas: np.ndarray, weight_map: np.ndarray) -> np.ndarray:
    """Normalize the accumulated canvas by its accumulated weights.

    Pixels with zero total weight (uncovered) keep the raw weighted sum,
    avoiding the black borders an epsilon division would produce.

    Args:
        canvas: Weighted-sum canvas, shape (H, W, C).
        weight_map: Weight accumulator, shape (H, W).

    Returns:
        Normalized canvas, shape (H, W, C), float32.
    """
    out = canvas.copy()
    has_weight = weight_map > 0
    out[has_weight] = canvas[has_weight] / weight_map[has_weight, np.newaxis]
    # Uncovered pixels are left untouched — boundary-aware masks should
    # never actually produce any.
    return out
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
# =============================================================================
|
| 423 |
-
# Seam-fix band planning
|
| 424 |
-
# =============================================================================
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
def plan_seam_fix_bands(
    tiles: list[TileSpec],
    image_width: int,
    image_height: int,
    seam_fix_width: int = 64,
    seam_fix_padding: int = 16,
) -> list[SeamFixSpec]:
    """Plan seam-fix bands along tile boundaries.

    For each pair of adjacent core regions, creates a band centered on the
    shared boundary. Bands are denoised in a second pass to smooth seams.

    Args:
        tiles: The tile plan (from plan_tiles_linear or plan_tiles_chess).
        image_width: Full image width.
        image_height: Full image height.
        seam_fix_width: Width of the seam-fix band in pixels.
        seam_fix_padding: Additional padding around each band for denoise context.

    Returns:
        List of ``SeamFixSpec`` for all seam boundaries, horizontal bands
        first, each group sorted by boundary position.

    Raises:
        ValueError: If ``seam_fix_width`` or ``seam_fix_padding`` is negative.
    """
    if seam_fix_width < 0:
        raise ValueError(f"`seam_fix_width` must be non-negative, got {seam_fix_width}.")
    if seam_fix_width == 0:
        # Width 0 means seam fixing is disabled.
        return []
    if seam_fix_padding < 0:
        raise ValueError(f"`seam_fix_padding` must be non-negative, got {seam_fix_padding}.")

    # Collect unique boundary positions. Sets deduplicate boundaries shared
    # by neighboring tiles; only interior edges (inside the image) count.
    h_boundaries: set[tuple[int, int, int]] = set()  # (y, x_start, x_end)
    v_boundaries: set[tuple[int, int, int]] = set()  # (x, y_start, y_end)

    for tile in tiles:
        # Bottom edge of this tile → horizontal seam
        bottom_y = tile.core_y + tile.core_h
        if bottom_y < image_height:
            h_boundaries.add((bottom_y, tile.core_x, tile.core_x + tile.core_w))

        # Right edge → vertical seam
        right_x = tile.core_x + tile.core_w
        if right_x < image_width:
            v_boundaries.add((right_x, tile.core_y, tile.core_y + tile.core_h))

    bands: list[SeamFixSpec] = []
    # Split the band width across the two sides of the seam line; for odd
    # widths the extra pixel goes to the trailing side.
    half_left = seam_fix_width // 2
    half_right = seam_fix_width - half_left

    # sorted() makes the output deterministic regardless of set iteration order.
    for y, x_start, x_end in sorted(h_boundaries):
        # Band straddles the seam at y, clamped to the canvas.
        band_y = max(0, y - half_left)
        band_y2 = min(image_height, y + half_right)
        band_h = band_y2 - band_y
        band_w = x_end - x_start

        # Padded crop for denoise context, clamped to the canvas.
        crop_x = max(0, x_start - seam_fix_padding)
        crop_y = max(0, band_y - seam_fix_padding)
        crop_x2 = min(image_width, x_end + seam_fix_padding)
        crop_y2 = min(image_height, band_y2 + seam_fix_padding)

        bands.append(SeamFixSpec(
            band_x=x_start, band_y=band_y, band_w=band_w, band_h=band_h,
            crop_x=crop_x, crop_y=crop_y,
            crop_w=crop_x2 - crop_x, crop_h=crop_y2 - crop_y,
            paste_x=x_start - crop_x, paste_y=band_y - crop_y,
            orientation="horizontal",
        ))

    for x, y_start, y_end in sorted(v_boundaries):
        # Band straddles the seam at x, clamped to the canvas.
        band_x = max(0, x - half_left)
        band_x2 = min(image_width, x + half_right)
        band_w = band_x2 - band_x
        band_h = y_end - y_start

        crop_x = max(0, band_x - seam_fix_padding)
        crop_y = max(0, y_start - seam_fix_padding)
        crop_x2 = min(image_width, band_x2 + seam_fix_padding)
        crop_y2 = min(image_height, y_end + seam_fix_padding)

        bands.append(SeamFixSpec(
            band_x=band_x, band_y=y_start, band_w=band_w, band_h=band_h,
            crop_x=crop_x, crop_y=crop_y,
            crop_w=crop_x2 - crop_x, crop_h=crop_y2 - crop_y,
            paste_x=band_x - crop_x, paste_y=y_start - crop_y,
            orientation="vertical",
        ))

    return bands
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
def extract_band_from_decoded(decoded_image: np.ndarray, band: SeamFixSpec) -> np.ndarray:
    """Slice the band region out of a decoded seam-fix crop."""
    top, left = band.paste_y, band.paste_x
    return decoded_image[top : top + band.band_h, left : left + band.band_w]
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
def make_seam_fix_mask(band: SeamFixSpec, mask_blur: int = 8) -> np.ndarray:
    """Build a linearly feathered blend mask for a seam-fix band.

    The mask is 1.0 along the seam and ramps linearly down to 0.0 over
    ``mask_blur`` pixels at the two edges perpendicular to the seam, so the
    re-denoised band blends smoothly into the surrounding tile results.

    Args:
        band: Seam-fix band specification.
        mask_blur: Width of the linear feather ramp in pixels.

    Returns:
        Mask of shape (band_h, band_w), float32, values in [0, 1].
    """
    mask = np.ones((band.band_h, band.band_w), dtype=np.float32)
    if mask_blur <= 0:
        return mask

    horizontal = band.orientation == "horizontal"
    # Feather perpendicular to the seam; cap the ramp at half the span so
    # the two ramps never cross.
    span = band.band_h if horizontal else band.band_w
    ramp = min(mask_blur, span // 2)
    if ramp > 0:
        rise = np.linspace(0.0, 1.0, ramp, dtype=np.float32)
        fall = np.linspace(1.0, 0.0, ramp, dtype=np.float32)
        if horizontal:
            # Fade the top and bottom edges.
            mask[:ramp, :] = rise[:, np.newaxis]
            mask[-ramp:, :] = fall[:, np.newaxis]
        else:
            # Fade the left and right edges.
            mask[:, :ramp] = rise[np.newaxis, :]
            mask[:, -ramp:] = fall[np.newaxis, :]

    return mask
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
def paste_seam_fix_band(
    canvas: np.ndarray,
    band_image: np.ndarray,
    band: SeamFixSpec,
    mask_blur: int = 8,
) -> None:
    """Paste a seam-fix band into the canvas with feathered blending.

    Args:
        canvas: Output canvas, shape (H, W, C), float32. Modified in-place.
        band_image: Decoded band pixels, shape (band_h, band_w, C), float32.
        band: Seam-fix band specification.
        mask_blur: Feathering width.
    """
    mask = make_seam_fix_mask(band, mask_blur)

    y1, y2 = band.band_y, band.band_y + band.band_h
    x1, x2 = band.band_x, band.band_x + band.band_w

    # Fix: `existing` was referenced below without being defined in the
    # visible text (NameError at runtime). Bind the current canvas contents
    # before the blended write-back. A view is sufficient because the RHS is
    # fully evaluated before the assignment stores into the canvas.
    existing = canvas[y1:y2, x1:x2]

    alpha = mask[:, :, np.newaxis]
    canvas[y1:y2, x1:x2] = existing * (1 - alpha) + band_image * alpha
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
# =============================================================================
|
| 587 |
-
# Latent-space tile planning for MultiDiffusion
|
| 588 |
-
# =============================================================================
|
| 589 |
|
| 590 |
|
| 591 |
@dataclass
|
| 592 |
class LatentTileSpec:
|
| 593 |
-
"""Tile specification in latent space
|
| 594 |
|
| 595 |
Attributes:
|
| 596 |
y: Top edge in latent pixels.
|
|
@@ -605,26 +32,7 @@ class LatentTileSpec:
|
|
| 605 |
w: int
|
| 606 |
|
| 607 |
|
| 608 |
-
def
|
| 609 |
-
latent_h: int,
|
| 610 |
-
latent_w: int,
|
| 611 |
-
tile_size: int = 64,
|
| 612 |
-
overlap: int = 8,
|
| 613 |
-
) -> list[LatentTileSpec]:
|
| 614 |
-
"""Plan overlapping tiles in latent space for MultiDiffusion.
|
| 615 |
-
|
| 616 |
-
Tiles overlap by ``overlap`` latent pixels. The stride is
|
| 617 |
-
``tile_size - overlap``. Edge tiles are clamped to the latent bounds.
|
| 618 |
-
|
| 619 |
-
Args:
|
| 620 |
-
latent_h: Height of the full latent tensor.
|
| 621 |
-
latent_w: Width of the full latent tensor.
|
| 622 |
-
tile_size: Tile size in latent pixels (e.g., 64 = 512px at scale 8).
|
| 623 |
-
overlap: Overlap in latent pixels (e.g., 8 = 64px at scale 8).
|
| 624 |
-
|
| 625 |
-
Returns:
|
| 626 |
-
List of ``LatentTileSpec``.
|
| 627 |
-
"""
|
| 628 |
if tile_size <= 0:
|
| 629 |
raise ValueError(f"`tile_size` must be positive, got {tile_size}.")
|
| 630 |
if overlap < 0:
|
|
@@ -635,13 +43,26 @@ def plan_latent_tiles(
|
|
| 635 |
f"Got overlap={overlap}, tile_size={tile_size}."
|
| 636 |
)
|
| 637 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 638 |
stride = tile_size - overlap
|
| 639 |
tiles: list[LatentTileSpec] = []
|
| 640 |
|
| 641 |
y = 0
|
| 642 |
while y < latent_h:
|
| 643 |
h = min(tile_size, latent_h - y)
|
| 644 |
-
# If remaining height is less than tile_size, shift back to get a full tile
|
| 645 |
if h < tile_size and y > 0:
|
| 646 |
y = max(0, latent_h - tile_size)
|
| 647 |
h = latent_h - y
|
|
@@ -666,6 +87,50 @@ def plan_latent_tiles(
|
|
| 666 |
return tiles
|
| 667 |
|
| 668 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 669 |
# ============================================================
|
| 670 |
# input
|
| 671 |
# ============================================================
|
|
@@ -684,6 +149,8 @@ def plan_latent_tiles(
|
|
| 684 |
# See the License for the specific language governing permissions and
|
| 685 |
# limitations under the License.
|
| 686 |
|
|
|
|
|
|
|
| 687 |
import PIL.Image
|
| 688 |
import torch
|
| 689 |
|
|
@@ -699,43 +166,29 @@ logger = logging.get_logger(__name__)
|
|
| 699 |
class UltimateSDUpscaleTextEncoderStep(StableDiffusionXLTextEncoderStep):
|
| 700 |
"""SDXL text encoder step that applies guidance scale before encoding.
|
| 701 |
|
| 702 |
-
|
| 703 |
-
embeddings
|
| 704 |
-
current `components.guider.guidance_scale` value.
|
| 705 |
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
text encoding, a previous run can leave the guider in a stale state and
|
| 709 |
-
cause missing negative embeddings on the next run.
|
| 710 |
-
|
| 711 |
-
Also applies a sensible default negative prompt for upscaling when the user
|
| 712 |
-
does not provide one, controlled by ``use_default_negative``.
|
| 713 |
"""
|
| 714 |
|
| 715 |
DEFAULT_NEGATIVE_PROMPT = "blurry, low quality, artifacts, noise, jpeg compression"
|
| 716 |
|
| 717 |
@property
|
| 718 |
def inputs(self) -> list[InputParam]:
|
| 719 |
-
# Keep all SDXL text-encoder inputs and add guidance_scale override.
|
| 720 |
return super().inputs + [
|
| 721 |
InputParam(
|
| 722 |
"guidance_scale",
|
| 723 |
type_hint=float,
|
| 724 |
default=7.5,
|
| 725 |
-
description=
|
| 726 |
-
"Classifier-Free Guidance scale used to configure the guider "
|
| 727 |
-
"before prompt encoding."
|
| 728 |
-
),
|
| 729 |
),
|
| 730 |
InputParam(
|
| 731 |
"use_default_negative",
|
| 732 |
type_hint=bool,
|
| 733 |
default=True,
|
| 734 |
-
description=
|
| 735 |
-
"When True and negative_prompt is None or empty, apply a default "
|
| 736 |
-
"negative prompt optimized for upscaling: "
|
| 737 |
-
"'blurry, low quality, artifacts, noise, jpeg compression'."
|
| 738 |
-
),
|
| 739 |
),
|
| 740 |
]
|
| 741 |
|
|
@@ -747,7 +200,6 @@ class UltimateSDUpscaleTextEncoderStep(StableDiffusionXLTextEncoderStep):
|
|
| 747 |
if hasattr(components, "guider") and components.guider is not None:
|
| 748 |
components.guider.guidance_scale = guidance_scale
|
| 749 |
|
| 750 |
-
# Apply default negative prompt if user didn't provide one
|
| 751 |
use_default_negative = getattr(block_state, "use_default_negative", True)
|
| 752 |
if use_default_negative:
|
| 753 |
neg = getattr(block_state, "negative_prompt", None)
|
|
@@ -759,56 +211,27 @@ class UltimateSDUpscaleTextEncoderStep(StableDiffusionXLTextEncoderStep):
|
|
| 759 |
|
| 760 |
|
| 761 |
class UltimateSDUpscaleUpscaleStep(ModularPipelineBlocks):
|
| 762 |
-
"""Upscales the input image using Lanczos interpolation.
|
| 763 |
-
|
| 764 |
-
This is the first custom step in the tiled upscaling workflow.
|
| 765 |
-
It takes an input image and upscale factor, producing an upscaled image
|
| 766 |
-
that subsequent tile steps will refine.
|
| 767 |
-
"""
|
| 768 |
|
| 769 |
@property
|
| 770 |
def description(self) -> str:
|
| 771 |
-
return
|
| 772 |
-
"Upscale step that resizes the input image by a given factor.\n"
|
| 773 |
-
"Currently supports Lanczos interpolation. Model-based upscalers "
|
| 774 |
-
"can be added in future passes."
|
| 775 |
-
)
|
| 776 |
|
| 777 |
@property
|
| 778 |
def inputs(self) -> list[InputParam]:
|
| 779 |
return [
|
| 780 |
-
InputParam(
|
| 781 |
-
|
| 782 |
-
|
| 783 |
-
|
| 784 |
-
description="The input image to upscale and refine.",
|
| 785 |
-
),
|
| 786 |
-
InputParam(
|
| 787 |
-
"upscale_factor",
|
| 788 |
-
type_hint=float,
|
| 789 |
-
default=2.0,
|
| 790 |
-
description="Factor by which to upscale the input image.",
|
| 791 |
-
),
|
| 792 |
]
|
| 793 |
|
| 794 |
@property
|
| 795 |
def intermediate_outputs(self) -> list[OutputParam]:
|
| 796 |
return [
|
| 797 |
-
OutputParam(
|
| 798 |
-
|
| 799 |
-
|
| 800 |
-
description="The upscaled image before tile-based refinement.",
|
| 801 |
-
),
|
| 802 |
-
OutputParam(
|
| 803 |
-
"upscaled_width",
|
| 804 |
-
type_hint=int,
|
| 805 |
-
description="Width of the upscaled image.",
|
| 806 |
-
),
|
| 807 |
-
OutputParam(
|
| 808 |
-
"upscaled_height",
|
| 809 |
-
type_hint=int,
|
| 810 |
-
description="Height of the upscaled image.",
|
| 811 |
-
),
|
| 812 |
]
|
| 813 |
|
| 814 |
@torch.no_grad()
|
|
@@ -819,10 +242,7 @@ class UltimateSDUpscaleUpscaleStep(ModularPipelineBlocks):
|
|
| 819 |
upscale_factor = block_state.upscale_factor
|
| 820 |
|
| 821 |
if not isinstance(image, PIL.Image.Image):
|
| 822 |
-
raise ValueError(
|
| 823 |
-
f"Expected `image` to be a PIL.Image.Image, got {type(image)}. "
|
| 824 |
-
"Please pass a PIL image to the pipeline."
|
| 825 |
-
)
|
| 826 |
|
| 827 |
new_width = int(image.width * upscale_factor)
|
| 828 |
new_height = int(image.height * upscale_factor)
|
|
@@ -831,1107 +251,209 @@ class UltimateSDUpscaleUpscaleStep(ModularPipelineBlocks):
|
|
| 831 |
block_state.upscaled_width = new_width
|
| 832 |
block_state.upscaled_height = new_height
|
| 833 |
|
| 834 |
-
logger.info(
|
| 835 |
-
f"Upscaled image from {image.width}x{image.height} to {new_width}x{new_height} "
|
| 836 |
-
f"(factor={upscale_factor})"
|
| 837 |
-
)
|
| 838 |
-
|
| 839 |
-
self.set_block_state(state, block_state)
|
| 840 |
-
return components, state
|
| 841 |
-
|
| 842 |
-
|
| 843 |
-
class UltimateSDUpscaleTilePlanStep(ModularPipelineBlocks):
|
| 844 |
-
"""Plans the tile grid for the upscaled image.
|
| 845 |
-
|
| 846 |
-
Generates a list of ``TileSpec`` objects based on the requested tile size
|
| 847 |
-
and padding. Supports linear (raster) and chess (checkerboard) traversal.
|
| 848 |
-
Optionally plans seam-fix bands along tile boundaries.
|
| 849 |
-
"""
|
| 850 |
-
|
| 851 |
-
@property
|
| 852 |
-
def description(self) -> str:
|
| 853 |
-
return (
|
| 854 |
-
"Tile planning step that generates tile coordinates for the upscaled image.\n"
|
| 855 |
-
"Supports 'linear' (raster) and 'chess' (checkerboard) traversal.\n"
|
| 856 |
-
"Optionally plans seam-fix bands along tile boundaries."
|
| 857 |
-
)
|
| 858 |
-
|
| 859 |
-
@property
|
| 860 |
-
def inputs(self) -> list[InputParam]:
|
| 861 |
-
return [
|
| 862 |
-
InputParam("upscaled_width", type_hint=int, required=True,
|
| 863 |
-
description="Width of the upscaled image."),
|
| 864 |
-
InputParam("upscaled_height", type_hint=int, required=True,
|
| 865 |
-
description="Height of the upscaled image."),
|
| 866 |
-
InputParam("tile_size", type_hint=int, default=2048,
|
| 867 |
-
description="Base tile size in pixels. Default 2048 processes most images "
|
| 868 |
-
"in a single pass (seamless). Set to 512 for tiled mode on very large images."),
|
| 869 |
-
InputParam("tile_padding", type_hint=int, default=32,
|
| 870 |
-
description="Number of overlap pixels on each side of a tile. Only relevant when tiling."),
|
| 871 |
-
InputParam("traversal_mode", type_hint=str, default="linear",
|
| 872 |
-
description="Tile traversal order: 'linear' or 'chess'."),
|
| 873 |
-
InputParam("seam_fix_width", type_hint=int, default=0,
|
| 874 |
-
description="Width of seam-fix bands in pixels. 0 disables seam fixing."),
|
| 875 |
-
InputParam("seam_fix_padding", type_hint=int, default=16,
|
| 876 |
-
description="Extra padding around seam-fix bands for denoise context."),
|
| 877 |
-
InputParam("seam_fix_mask_blur", type_hint=int, default=8,
|
| 878 |
-
description="Feathering width for seam-fix blending masks."),
|
| 879 |
-
]
|
| 880 |
-
|
| 881 |
-
@property
|
| 882 |
-
def intermediate_outputs(self) -> list[OutputParam]:
|
| 883 |
-
return [
|
| 884 |
-
OutputParam("tile_plan", type_hint=list,
|
| 885 |
-
description="List of TileSpec defining the tile grid."),
|
| 886 |
-
OutputParam("num_tiles", type_hint=int,
|
| 887 |
-
description="Total number of tiles in the plan."),
|
| 888 |
-
OutputParam("seam_fix_plan", type_hint=list,
|
| 889 |
-
description="List of SeamFixSpec for seam-fix bands (empty if disabled)."),
|
| 890 |
-
OutputParam("seam_fix_mask_blur", type_hint=int,
|
| 891 |
-
description="Feathering width for seam-fix blending."),
|
| 892 |
-
]
|
| 893 |
-
|
| 894 |
-
@torch.no_grad()
|
| 895 |
-
def __call__(self, components, state: PipelineState) -> PipelineState:
|
| 896 |
-
block_state = self.get_block_state(state)
|
| 897 |
-
|
| 898 |
-
tile_size = block_state.tile_size
|
| 899 |
-
tile_padding = block_state.tile_padding
|
| 900 |
-
traversal_mode = block_state.traversal_mode
|
| 901 |
-
|
| 902 |
-
if traversal_mode not in ("linear", "chess"):
|
| 903 |
-
raise ValueError(
|
| 904 |
-
f"Unsupported traversal_mode '{traversal_mode}'. "
|
| 905 |
-
"Supported modes: 'linear', 'chess'."
|
| 906 |
-
)
|
| 907 |
-
|
| 908 |
-
validate_tile_params(tile_size, tile_padding)
|
| 909 |
-
|
| 910 |
-
if traversal_mode == "chess":
|
| 911 |
-
tile_plan = plan_tiles_chess(
|
| 912 |
-
image_width=block_state.upscaled_width,
|
| 913 |
-
image_height=block_state.upscaled_height,
|
| 914 |
-
tile_size=tile_size,
|
| 915 |
-
tile_padding=tile_padding,
|
| 916 |
-
)
|
| 917 |
-
else:
|
| 918 |
-
tile_plan = plan_tiles_linear(
|
| 919 |
-
image_width=block_state.upscaled_width,
|
| 920 |
-
image_height=block_state.upscaled_height,
|
| 921 |
-
tile_size=tile_size,
|
| 922 |
-
tile_padding=tile_padding,
|
| 923 |
-
)
|
| 924 |
-
|
| 925 |
-
# Validate and plan seam-fix bands if enabled
|
| 926 |
-
seam_fix_width = block_state.seam_fix_width
|
| 927 |
-
seam_fix_padding = block_state.seam_fix_padding
|
| 928 |
-
seam_fix_mask_blur = block_state.seam_fix_mask_blur
|
| 929 |
-
|
| 930 |
-
if seam_fix_width < 0:
|
| 931 |
-
raise ValueError(f"`seam_fix_width` must be non-negative, got {seam_fix_width}.")
|
| 932 |
-
if seam_fix_padding < 0:
|
| 933 |
-
raise ValueError(f"`seam_fix_padding` must be non-negative, got {seam_fix_padding}.")
|
| 934 |
-
if seam_fix_mask_blur < 0:
|
| 935 |
-
raise ValueError(f"`seam_fix_mask_blur` must be non-negative, got {seam_fix_mask_blur}.")
|
| 936 |
-
|
| 937 |
-
if seam_fix_width > 0:
|
| 938 |
-
seam_fix_plan = plan_seam_fix_bands(
|
| 939 |
-
tiles=tile_plan,
|
| 940 |
-
image_width=block_state.upscaled_width,
|
| 941 |
-
image_height=block_state.upscaled_height,
|
| 942 |
-
seam_fix_width=seam_fix_width,
|
| 943 |
-
seam_fix_padding=seam_fix_padding,
|
| 944 |
-
)
|
| 945 |
-
else:
|
| 946 |
-
seam_fix_plan = []
|
| 947 |
-
|
| 948 |
-
block_state.tile_plan = tile_plan
|
| 949 |
-
block_state.num_tiles = len(tile_plan)
|
| 950 |
-
block_state.seam_fix_plan = seam_fix_plan
|
| 951 |
-
block_state.seam_fix_mask_blur = seam_fix_mask_blur
|
| 952 |
-
|
| 953 |
-
logger.info(
|
| 954 |
-
f"Planned {len(tile_plan)} tiles "
|
| 955 |
-
f"(tile_size={tile_size}, padding={tile_padding}, traversal={traversal_mode})"
|
| 956 |
-
+ (f", {len(seam_fix_plan)} seam-fix bands" if seam_fix_plan else "")
|
| 957 |
-
)
|
| 958 |
|
| 959 |
self.set_block_state(state, block_state)
|
| 960 |
return components, state
|
| 961 |
|
| 962 |
|
| 963 |
-
# ============================================================
|
| 964 |
-
# denoise
|
| 965 |
-
# ============================================================
|
| 966 |
-
|
| 967 |
-
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 968 |
-
#
|
| 969 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 970 |
-
# you may not use this file except in compliance with the License.
|
| 971 |
-
# You may obtain a copy of the License at
|
| 972 |
-
#
|
| 973 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 974 |
-
#
|
| 975 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 976 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 977 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 978 |
-
# See the License for the specific language governing permissions and
|
| 979 |
-
# limitations under the License.
|
| 980 |
-
|
| 981 |
-
"""Tiled upscaling denoise steps for Modular SDXL Upscale.
|
| 982 |
-
|
| 983 |
-
Architecture follows the ``LoopSequentialPipelineBlocks`` pattern used by the
|
| 984 |
-
SDXL denoising loop. ``UltimateSDUpscaleTileLoopStep`` is the loop wrapper
|
| 985 |
-
(iterates over *tiles*); its sub-blocks are leaf blocks that handle one tile
|
| 986 |
-
per call:
|
| 987 |
-
|
| 988 |
-
TilePrepareStep – crop, VAE encode, prepare latents, tile-aware add_cond
|
| 989 |
-
TileDenoiserStep – full denoising loop (wraps ``StableDiffusionXLDenoiseStep``)
|
| 990 |
-
TilePostProcessStep – decode latents, extract core, paste into canvas
|
| 991 |
-
|
| 992 |
-
SDXL blocks are reused via their public interface by creating temporary
|
| 993 |
-
``PipelineState`` objects, NOT by calling private helpers.
|
| 994 |
-
"""
|
| 995 |
-
|
| 996 |
-
import math
|
| 997 |
-
import time
|
| 998 |
-
|
| 999 |
-
import numpy as np
|
| 1000 |
-
import PIL.Image
|
| 1001 |
-
import torch
|
| 1002 |
-
from tqdm.auto import tqdm
|
| 1003 |
-
|
| 1004 |
-
from diffusers.configuration_utils import FrozenDict
|
| 1005 |
-
from diffusers.guiders import ClassifierFreeGuidance
|
| 1006 |
-
from diffusers.image_processor import VaeImageProcessor
|
| 1007 |
-
from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
| 1008 |
-
from diffusers.schedulers import DPMSolverMultistepScheduler, EulerDiscreteScheduler
|
| 1009 |
-
from diffusers.utils import logging
|
| 1010 |
-
from diffusers.utils.torch_utils import randn_tensor
|
| 1011 |
-
from diffusers.modular_pipelines.modular_pipeline import (
|
| 1012 |
-
BlockState,
|
| 1013 |
-
LoopSequentialPipelineBlocks,
|
| 1014 |
-
ModularPipelineBlocks,
|
| 1015 |
-
PipelineState,
|
| 1016 |
-
)
|
| 1017 |
-
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
| 1018 |
-
from diffusers.modular_pipelines.stable_diffusion_xl.before_denoise import (
|
| 1019 |
-
StableDiffusionXLControlNetInputStep,
|
| 1020 |
-
StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
|
| 1021 |
-
StableDiffusionXLImg2ImgPrepareLatentsStep,
|
| 1022 |
-
prepare_latents_img2img,
|
| 1023 |
-
)
|
| 1024 |
-
from diffusers.modular_pipelines.stable_diffusion_xl.decoders import StableDiffusionXLDecodeStep
|
| 1025 |
-
from diffusers.modular_pipelines.stable_diffusion_xl.denoise import StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep
|
| 1026 |
-
from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLVaeEncoderStep
|
| 1027 |
-
|
| 1028 |
-
|
| 1029 |
-
logger = logging.get_logger(__name__)
|
| 1030 |
-
|
| 1031 |
-
|
| 1032 |
-
# ---------------------------------------------------------------------------
|
| 1033 |
-
# Helper: populate a PipelineState from a dict
|
| 1034 |
-
# ---------------------------------------------------------------------------
|
| 1035 |
-
|
| 1036 |
-
def _make_state(values: dict, kwargs_type_map: dict | None = None) -> PipelineState:
|
| 1037 |
-
"""Create a PipelineState and set values, optionally with kwargs_type."""
|
| 1038 |
-
state = PipelineState()
|
| 1039 |
-
kwargs_type_map = kwargs_type_map or {}
|
| 1040 |
-
for k, v in values.items():
|
| 1041 |
-
state.set(k, v, kwargs_type_map.get(k))
|
| 1042 |
-
return state
|
| 1043 |
-
|
| 1044 |
-
|
| 1045 |
-
def _to_pil_rgb_image(image) -> PIL.Image.Image:
|
| 1046 |
-
"""Convert a tensor/ndarray/PIL image to a RGB PIL image."""
|
| 1047 |
-
if isinstance(image, PIL.Image.Image):
|
| 1048 |
-
return image.convert("RGB")
|
| 1049 |
-
|
| 1050 |
-
if torch.is_tensor(image):
|
| 1051 |
-
tensor = image.detach().cpu()
|
| 1052 |
-
if tensor.ndim == 4:
|
| 1053 |
-
if tensor.shape[0] != 1:
|
| 1054 |
-
raise ValueError(
|
| 1055 |
-
f"`control_image` tensor batch must be 1 for tiled upscaling, got shape {tuple(tensor.shape)}."
|
| 1056 |
-
)
|
| 1057 |
-
tensor = tensor[0]
|
| 1058 |
-
if tensor.ndim == 3 and tensor.shape[0] in (1, 3, 4) and tensor.shape[-1] not in (1, 3, 4):
|
| 1059 |
-
tensor = tensor.permute(1, 2, 0)
|
| 1060 |
-
image = tensor.numpy()
|
| 1061 |
-
|
| 1062 |
-
if isinstance(image, np.ndarray):
|
| 1063 |
-
array = image
|
| 1064 |
-
if array.ndim == 4:
|
| 1065 |
-
if array.shape[0] != 1:
|
| 1066 |
-
raise ValueError(
|
| 1067 |
-
f"`control_image` ndarray batch must be 1 for tiled upscaling, got shape {array.shape}."
|
| 1068 |
-
)
|
| 1069 |
-
array = array[0]
|
| 1070 |
-
if array.ndim == 3 and array.shape[0] in (1, 3, 4) and array.shape[-1] not in (1, 3, 4):
|
| 1071 |
-
array = np.transpose(array, (1, 2, 0))
|
| 1072 |
-
if array.ndim == 2:
|
| 1073 |
-
array = np.stack([array] * 3, axis=-1)
|
| 1074 |
-
if array.ndim != 3:
|
| 1075 |
-
raise ValueError(f"`control_image` must have 2 or 3 dimensions, got shape {array.shape}.")
|
| 1076 |
-
if array.shape[-1] == 1:
|
| 1077 |
-
array = np.repeat(array, 3, axis=-1)
|
| 1078 |
-
if array.shape[-1] == 4:
|
| 1079 |
-
array = array[..., :3]
|
| 1080 |
-
if array.shape[-1] != 3:
|
| 1081 |
-
raise ValueError(f"`control_image` channel dimension must be 1/3/4, got shape {array.shape}.")
|
| 1082 |
-
if array.dtype != np.uint8:
|
| 1083 |
-
array = np.asarray(array, dtype=np.float32)
|
| 1084 |
-
max_val = float(np.max(array)) if array.size > 0 else 1.0
|
| 1085 |
-
if max_val <= 1.0:
|
| 1086 |
-
array = (np.clip(array, 0.0, 1.0) * 255.0).astype(np.uint8)
|
| 1087 |
-
else:
|
| 1088 |
-
array = np.clip(array, 0.0, 255.0).astype(np.uint8)
|
| 1089 |
-
return PIL.Image.fromarray(array).convert("RGB")
|
| 1090 |
-
|
| 1091 |
-
raise ValueError(
|
| 1092 |
-
f"Unsupported `control_image` type {type(image)}. Expected PIL.Image, torch.Tensor, or numpy.ndarray."
|
| 1093 |
-
)
|
| 1094 |
-
|
| 1095 |
-
|
| 1096 |
-
# ---------------------------------------------------------------------------
|
| 1097 |
-
# Scheduler swap helper (Feature 5)
|
| 1098 |
-
# ---------------------------------------------------------------------------
|
| 1099 |
-
|
| 1100 |
-
_SCHEDULER_ALIASES = {
|
| 1101 |
-
"euler": "EulerDiscreteScheduler",
|
| 1102 |
-
"euler discrete": "EulerDiscreteScheduler",
|
| 1103 |
-
"eulerdiscretescheduler": "EulerDiscreteScheduler",
|
| 1104 |
-
"dpm++ 2m": "DPMSolverMultistepScheduler",
|
| 1105 |
-
"dpmsolvermultistepscheduler": "DPMSolverMultistepScheduler",
|
| 1106 |
-
"dpm++ 2m karras": "DPMSolverMultistepScheduler+karras",
|
| 1107 |
-
}
|
| 1108 |
-
|
| 1109 |
-
|
| 1110 |
-
def _swap_scheduler(components, scheduler_name: str):
|
| 1111 |
-
"""Swap the scheduler on ``components`` given a human-readable name.
|
| 1112 |
-
|
| 1113 |
-
Supported names (case-insensitive):
|
| 1114 |
-
- ``"Euler"`` / ``"EulerDiscreteScheduler"``
|
| 1115 |
-
- ``"DPM++ 2M"`` / ``"DPMSolverMultistepScheduler"``
|
| 1116 |
-
- ``"DPM++ 2M Karras"`` (DPMSolverMultistep with Karras sigmas)
|
| 1117 |
-
|
| 1118 |
-
If the requested scheduler is already active, this is a no-op.
|
| 1119 |
-
"""
|
| 1120 |
-
key = scheduler_name.strip().lower()
|
| 1121 |
-
resolved = _SCHEDULER_ALIASES.get(key, key)
|
| 1122 |
-
|
| 1123 |
-
use_karras = resolved.endswith("+karras")
|
| 1124 |
-
if use_karras:
|
| 1125 |
-
resolved = resolved.replace("+karras", "")
|
| 1126 |
-
|
| 1127 |
-
current = type(components.scheduler).__name__
|
| 1128 |
-
|
| 1129 |
-
if resolved == "EulerDiscreteScheduler":
|
| 1130 |
-
if current != "EulerDiscreteScheduler":
|
| 1131 |
-
components.scheduler = EulerDiscreteScheduler.from_config(components.scheduler.config)
|
| 1132 |
-
logger.info("Swapped scheduler to EulerDiscreteScheduler")
|
| 1133 |
-
elif resolved == "DPMSolverMultistepScheduler":
|
| 1134 |
-
if current != "DPMSolverMultistepScheduler" or (
|
| 1135 |
-
use_karras and not getattr(components.scheduler.config, "use_karras_sigmas", False)
|
| 1136 |
-
):
|
| 1137 |
-
extra_kwargs = {}
|
| 1138 |
-
if use_karras:
|
| 1139 |
-
extra_kwargs["use_karras_sigmas"] = True
|
| 1140 |
-
components.scheduler = DPMSolverMultistepScheduler.from_config(
|
| 1141 |
-
components.scheduler.config, **extra_kwargs
|
| 1142 |
-
)
|
| 1143 |
-
logger.info(f"Swapped scheduler to DPMSolverMultistepScheduler (karras={use_karras})")
|
| 1144 |
-
else:
|
| 1145 |
-
logger.warning(
|
| 1146 |
-
f"Unknown scheduler_name '{scheduler_name}'. Keeping current scheduler "
|
| 1147 |
-
f"({current}). Supported: 'Euler', 'DPM++ 2M', 'DPM++ 2M Karras'."
|
| 1148 |
-
)
|
| 1149 |
-
|
| 1150 |
-
|
| 1151 |
-
# ---------------------------------------------------------------------------
|
| 1152 |
-
# Auto-strength helper (Feature 2)
|
| 1153 |
-
# ---------------------------------------------------------------------------
|
| 1154 |
-
|
| 1155 |
-
def _compute_auto_strength(upscale_factor: float, pass_index: int, num_passes: int) -> float:
|
| 1156 |
-
"""Return the auto-scaled denoise strength for a given pass.
|
| 1157 |
-
|
| 1158 |
-
Rules:
|
| 1159 |
-
- Single-pass 2x: 0.3
|
| 1160 |
-
- Single-pass 4x: 0.15
|
| 1161 |
-
- Progressive passes: first pass=0.3, subsequent passes=0.2
|
| 1162 |
-
"""
|
| 1163 |
-
if num_passes > 1:
|
| 1164 |
-
return 0.3 if pass_index == 0 else 0.2
|
| 1165 |
-
# Single pass
|
| 1166 |
-
if upscale_factor <= 2.0:
|
| 1167 |
-
return 0.3
|
| 1168 |
-
elif upscale_factor <= 4.0:
|
| 1169 |
-
return 0.15
|
| 1170 |
-
else:
|
| 1171 |
-
return 0.1
|
| 1172 |
-
|
| 1173 |
-
|
| 1174 |
-
# ---------------------------------------------------------------------------
|
| 1175 |
-
# Loop sub-block 1: Prepare (crop + encode + timesteps + latents + add_cond)
|
| 1176 |
-
# ---------------------------------------------------------------------------
|
| 1177 |
-
|
| 1178 |
-
class UltimateSDUpscaleTilePrepareStep(ModularPipelineBlocks):
|
| 1179 |
-
"""Loop sub-block that prepares one tile for denoising.
|
| 1180 |
-
|
| 1181 |
-
For each tile it:
|
| 1182 |
-
1. Crops the padded region from the upscaled image.
|
| 1183 |
-
2. Calls ``StableDiffusionXLVaeEncoderStep`` to encode to latents.
|
| 1184 |
-
3. Resets the scheduler step index (reuses timesteps from the outer
|
| 1185 |
-
set_timesteps block — does NOT re-run set_timesteps to avoid
|
| 1186 |
-
double-applying strength).
|
| 1187 |
-
4. Calls ``StableDiffusionXLImg2ImgPrepareLatentsStep``.
|
| 1188 |
-
5. Calls ``StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep``
|
| 1189 |
-
with tile-aware ``crops_coords_top_left`` and ``target_size``.
|
| 1190 |
-
|
| 1191 |
-
All SDXL blocks are reused via their public ``__call__`` interface.
|
| 1192 |
-
"""
|
| 1193 |
-
|
| 1194 |
-
model_name = "stable-diffusion-xl"
|
| 1195 |
-
|
| 1196 |
-
def __init__(self):
|
| 1197 |
-
super().__init__()
|
| 1198 |
-
# Store SDXL blocks as attributes (NOT in sub_blocks → remains a leaf)
|
| 1199 |
-
self._vae_encoder = StableDiffusionXLVaeEncoderStep()
|
| 1200 |
-
self._prepare_latents = StableDiffusionXLImg2ImgPrepareLatentsStep()
|
| 1201 |
-
self._prepare_add_cond = StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep()
|
| 1202 |
-
self._prepare_controlnet = StableDiffusionXLControlNetInputStep()
|
| 1203 |
-
|
| 1204 |
-
@property
|
| 1205 |
-
def description(self) -> str:
|
| 1206 |
-
return (
|
| 1207 |
-
"Loop sub-block: crops a tile, encodes to latents, resets scheduler "
|
| 1208 |
-
"timesteps, prepares latents, and computes tile-aware additional conditioning."
|
| 1209 |
-
)
|
| 1210 |
-
|
| 1211 |
-
@property
|
| 1212 |
-
def expected_components(self) -> list[ComponentSpec]:
|
| 1213 |
-
return [
|
| 1214 |
-
ComponentSpec("vae", AutoencoderKL),
|
| 1215 |
-
ComponentSpec(
|
| 1216 |
-
"image_processor",
|
| 1217 |
-
VaeImageProcessor,
|
| 1218 |
-
config=FrozenDict({"vae_scale_factor": 8}),
|
| 1219 |
-
default_creation_method="from_config",
|
| 1220 |
-
),
|
| 1221 |
-
ComponentSpec("scheduler", EulerDiscreteScheduler),
|
| 1222 |
-
ComponentSpec("unet", UNet2DConditionModel),
|
| 1223 |
-
ComponentSpec(
|
| 1224 |
-
"guider",
|
| 1225 |
-
ClassifierFreeGuidance,
|
| 1226 |
-
config=FrozenDict({"guidance_scale": 7.5}),
|
| 1227 |
-
default_creation_method="from_config",
|
| 1228 |
-
),
|
| 1229 |
-
ComponentSpec(
|
| 1230 |
-
"control_image_processor",
|
| 1231 |
-
VaeImageProcessor,
|
| 1232 |
-
config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}),
|
| 1233 |
-
default_creation_method="from_config",
|
| 1234 |
-
),
|
| 1235 |
-
]
|
| 1236 |
-
|
| 1237 |
-
@property
|
| 1238 |
-
def expected_configs(self) -> list[ConfigSpec]:
|
| 1239 |
-
return [ConfigSpec("requires_aesthetics_score", False)]
|
| 1240 |
-
|
| 1241 |
-
@property
|
| 1242 |
-
def inputs(self) -> list[InputParam]:
|
| 1243 |
-
return [
|
| 1244 |
-
InputParam("upscaled_image", type_hint=PIL.Image.Image, required=True),
|
| 1245 |
-
InputParam("upscaled_height", type_hint=int, required=True),
|
| 1246 |
-
InputParam("upscaled_width", type_hint=int, required=True),
|
| 1247 |
-
InputParam("generator"),
|
| 1248 |
-
InputParam("batch_size", type_hint=int, required=True),
|
| 1249 |
-
InputParam("num_images_per_prompt", type_hint=int, default=1),
|
| 1250 |
-
InputParam("dtype", type_hint=torch.dtype, required=True),
|
| 1251 |
-
InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, kwargs_type="denoiser_input_fields"),
|
| 1252 |
-
InputParam("num_inference_steps", type_hint=int, default=50),
|
| 1253 |
-
InputParam("strength", type_hint=float, default=0.3),
|
| 1254 |
-
InputParam("timesteps", type_hint=torch.Tensor, required=True),
|
| 1255 |
-
InputParam("latent_timestep", type_hint=torch.Tensor, required=True),
|
| 1256 |
-
InputParam("denoising_start"),
|
| 1257 |
-
InputParam("denoising_end"),
|
| 1258 |
-
InputParam("use_controlnet", type_hint=bool, default=False),
|
| 1259 |
-
InputParam("control_image_processed"),
|
| 1260 |
-
InputParam("control_guidance_start", default=0.0),
|
| 1261 |
-
InputParam("control_guidance_end", default=1.0),
|
| 1262 |
-
InputParam("controlnet_conditioning_scale", default=1.0),
|
| 1263 |
-
InputParam("guess_mode", default=False),
|
| 1264 |
-
]
|
| 1265 |
-
|
| 1266 |
-
@property
|
| 1267 |
-
def intermediate_outputs(self) -> list[OutputParam]:
|
| 1268 |
-
return [
|
| 1269 |
-
OutputParam("latents", type_hint=torch.Tensor),
|
| 1270 |
-
OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
|
| 1271 |
-
OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
|
| 1272 |
-
OutputParam("timestep_cond", type_hint=torch.Tensor),
|
| 1273 |
-
OutputParam("controlnet_cond", type_hint=torch.Tensor),
|
| 1274 |
-
OutputParam("conditioning_scale"),
|
| 1275 |
-
OutputParam("controlnet_keep", type_hint=list[float]),
|
| 1276 |
-
OutputParam("guess_mode", type_hint=bool),
|
| 1277 |
-
]
|
| 1278 |
-
|
| 1279 |
-
@torch.no_grad()
|
| 1280 |
-
def __call__(self, components, block_state: BlockState, tile_idx: int, tile: TileSpec):
|
| 1281 |
-
# --- 1. Crop tile ---
|
| 1282 |
-
tile_image = crop_tile(block_state.upscaled_image, tile)
|
| 1283 |
-
|
| 1284 |
-
# --- 2. VAE encode tile ---
|
| 1285 |
-
enc_state = _make_state({
|
| 1286 |
-
"image": tile_image,
|
| 1287 |
-
"height": tile.crop_h,
|
| 1288 |
-
"width": tile.crop_w,
|
| 1289 |
-
"generator": block_state.generator,
|
| 1290 |
-
"dtype": block_state.dtype,
|
| 1291 |
-
"preprocess_kwargs": None,
|
| 1292 |
-
})
|
| 1293 |
-
components, enc_state = self._vae_encoder(components, enc_state)
|
| 1294 |
-
image_latents = enc_state.get("image_latents")
|
| 1295 |
-
|
| 1296 |
-
# --- 3. Reset scheduler step state for this tile ---
|
| 1297 |
-
# The outer set_timesteps block already computed the correct timesteps
|
| 1298 |
-
# and num_inference_steps (with strength applied). We must NOT re-run
|
| 1299 |
-
# set_timesteps here — that would double-apply strength and produce
|
| 1300 |
-
# 0 denoising steps. Instead, reset the scheduler's mutable step index
|
| 1301 |
-
# so it can iterate the same schedule again for this tile.
|
| 1302 |
-
scheduler = components.scheduler
|
| 1303 |
-
latent_timestep = block_state.latent_timestep
|
| 1304 |
-
|
| 1305 |
-
# Only reset _step_index (progress counter). Do NOT touch _begin_index —
|
| 1306 |
-
# it holds the correct start position computed by the outer set_timesteps
|
| 1307 |
-
# step (e.g., step 14 for strength=0.3 with 20 steps). Resetting it to 0
|
| 1308 |
-
# would make the scheduler use sigmas for full noise (timestep ~999) when
|
| 1309 |
-
# the latents only have partial noise (timestep ~250), producing garbage.
|
| 1310 |
-
if hasattr(scheduler, "_step_index"):
|
| 1311 |
-
scheduler._step_index = None
|
| 1312 |
-
if hasattr(scheduler, "is_scale_input_called"):
|
| 1313 |
-
scheduler.is_scale_input_called = False
|
| 1314 |
-
|
| 1315 |
-
# --- 4. Prepare latents ---
|
| 1316 |
-
# Build clean init latents first (no random noise yet), then add tile noise.
|
| 1317 |
-
# Using a global noise map keeps noise spatially consistent across tiles and
|
| 1318 |
-
# greatly reduces cross-tile drift/artifacts.
|
| 1319 |
-
clean_latents = prepare_latents_img2img(
|
| 1320 |
-
components.vae,
|
| 1321 |
-
components.scheduler,
|
| 1322 |
-
image_latents,
|
| 1323 |
-
latent_timestep,
|
| 1324 |
-
block_state.batch_size,
|
| 1325 |
-
block_state.num_images_per_prompt,
|
| 1326 |
-
block_state.dtype,
|
| 1327 |
-
image_latents.device,
|
| 1328 |
-
generator=None,
|
| 1329 |
-
add_noise=False,
|
| 1330 |
-
)
|
| 1331 |
-
|
| 1332 |
-
latent_h, latent_w = clean_latents.shape[-2], clean_latents.shape[-1]
|
| 1333 |
-
global_noise_map = getattr(block_state, "global_noise_map", None)
|
| 1334 |
-
if global_noise_map is not None:
|
| 1335 |
-
vae_scale_factor = int(getattr(block_state, "global_noise_scale", 8))
|
| 1336 |
-
y0 = max(0, tile.crop_y // vae_scale_factor)
|
| 1337 |
-
x0 = max(0, tile.crop_x // vae_scale_factor)
|
| 1338 |
-
max_y0 = max(0, global_noise_map.shape[-2] - latent_h)
|
| 1339 |
-
max_x0 = max(0, global_noise_map.shape[-1] - latent_w)
|
| 1340 |
-
y0 = min(y0, max_y0)
|
| 1341 |
-
x0 = min(x0, max_x0)
|
| 1342 |
-
tile_noise = global_noise_map[:, :, y0 : y0 + latent_h, x0 : x0 + latent_w]
|
| 1343 |
-
|
| 1344 |
-
# Defensive fallback if latent shape and crop math ever diverge.
|
| 1345 |
-
if tile_noise.shape != clean_latents.shape:
|
| 1346 |
-
tile_noise = randn_tensor(
|
| 1347 |
-
clean_latents.shape,
|
| 1348 |
-
generator=block_state.generator,
|
| 1349 |
-
device=clean_latents.device,
|
| 1350 |
-
dtype=clean_latents.dtype,
|
| 1351 |
-
)
|
| 1352 |
-
else:
|
| 1353 |
-
tile_noise = randn_tensor(
|
| 1354 |
-
clean_latents.shape,
|
| 1355 |
-
generator=block_state.generator,
|
| 1356 |
-
device=clean_latents.device,
|
| 1357 |
-
dtype=clean_latents.dtype,
|
| 1358 |
-
)
|
| 1359 |
-
|
| 1360 |
-
pre_noised_latents = components.scheduler.add_noise(clean_latents, tile_noise, latent_timestep)
|
| 1361 |
-
|
| 1362 |
-
lat_state = _make_state({
|
| 1363 |
-
"image_latents": image_latents,
|
| 1364 |
-
"latent_timestep": latent_timestep,
|
| 1365 |
-
"batch_size": block_state.batch_size,
|
| 1366 |
-
"num_images_per_prompt": block_state.num_images_per_prompt,
|
| 1367 |
-
"dtype": block_state.dtype,
|
| 1368 |
-
"generator": block_state.generator,
|
| 1369 |
-
"latents": pre_noised_latents,
|
| 1370 |
-
"denoising_start": getattr(block_state, "denoising_start", None),
|
| 1371 |
-
})
|
| 1372 |
-
components, lat_state = self._prepare_latents(components, lat_state)
|
| 1373 |
-
|
| 1374 |
-
# --- 5. Prepare additional conditioning (tile-aware) ---
|
| 1375 |
-
# crops_coords_top_left tells SDXL where this tile sits in the canvas
|
| 1376 |
-
# target_size is the tile's pixel dimensions
|
| 1377 |
-
# original_size is the full upscaled image dimensions
|
| 1378 |
-
cond_state = _make_state({
|
| 1379 |
-
"original_size": (block_state.upscaled_height, block_state.upscaled_width),
|
| 1380 |
-
"target_size": (tile.crop_h, tile.crop_w),
|
| 1381 |
-
"crops_coords_top_left": (tile.crop_y, tile.crop_x),
|
| 1382 |
-
"negative_original_size": None,
|
| 1383 |
-
"negative_target_size": None,
|
| 1384 |
-
"negative_crops_coords_top_left": (0, 0),
|
| 1385 |
-
"num_images_per_prompt": block_state.num_images_per_prompt,
|
| 1386 |
-
"aesthetic_score": 6.0,
|
| 1387 |
-
"negative_aesthetic_score": 2.0,
|
| 1388 |
-
"latents": lat_state.get("latents"),
|
| 1389 |
-
"pooled_prompt_embeds": block_state.pooled_prompt_embeds,
|
| 1390 |
-
"batch_size": block_state.batch_size,
|
| 1391 |
-
})
|
| 1392 |
-
components, cond_state = self._prepare_add_cond(components, cond_state)
|
| 1393 |
-
|
| 1394 |
-
# --- Write results to block_state ---
|
| 1395 |
-
# timesteps/num_inference_steps/latent_timestep are from the outer
|
| 1396 |
-
# set_timesteps step (already in block_state), no need to overwrite.
|
| 1397 |
-
block_state.latents = lat_state.get("latents")
|
| 1398 |
-
block_state.add_time_ids = cond_state.get("add_time_ids")
|
| 1399 |
-
block_state.negative_add_time_ids = cond_state.get("negative_add_time_ids")
|
| 1400 |
-
block_state.timestep_cond = cond_state.get("timestep_cond")
|
| 1401 |
-
if getattr(block_state, "use_controlnet", False):
|
| 1402 |
-
control_tile = crop_tile(block_state.control_image_processed, tile)
|
| 1403 |
-
control_state = _make_state({
|
| 1404 |
-
"control_image": control_tile,
|
| 1405 |
-
"control_guidance_start": getattr(block_state, "control_guidance_start", 0.0),
|
| 1406 |
-
"control_guidance_end": getattr(block_state, "control_guidance_end", 1.0),
|
| 1407 |
-
"controlnet_conditioning_scale": getattr(block_state, "controlnet_conditioning_scale", 1.0),
|
| 1408 |
-
"guess_mode": getattr(block_state, "guess_mode", False),
|
| 1409 |
-
"num_images_per_prompt": block_state.num_images_per_prompt,
|
| 1410 |
-
"latents": block_state.latents,
|
| 1411 |
-
"batch_size": block_state.batch_size,
|
| 1412 |
-
"timesteps": block_state.timesteps,
|
| 1413 |
-
"crops_coords": None,
|
| 1414 |
-
})
|
| 1415 |
-
components, control_state = self._prepare_controlnet(components, control_state)
|
| 1416 |
-
block_state.controlnet_cond = control_state.get("controlnet_cond")
|
| 1417 |
-
block_state.conditioning_scale = control_state.get("conditioning_scale")
|
| 1418 |
-
block_state.controlnet_keep = control_state.get("controlnet_keep")
|
| 1419 |
-
block_state.guess_mode = control_state.get("guess_mode")
|
| 1420 |
-
else:
|
| 1421 |
-
block_state.controlnet_cond = None
|
| 1422 |
-
block_state.conditioning_scale = None
|
| 1423 |
-
block_state.controlnet_keep = None
|
| 1424 |
-
|
| 1425 |
-
return components, block_state
|
| 1426 |
-
|
| 1427 |
-
|
| 1428 |
-
# ---------------------------------------------------------------------------
|
| 1429 |
-
# Loop sub-block 2: Denoise
|
| 1430 |
-
# ---------------------------------------------------------------------------
|
| 1431 |
-
|
| 1432 |
-
class UltimateSDUpscaleTileDenoiserStep(ModularPipelineBlocks):
|
| 1433 |
-
"""Loop sub-block that runs the full denoising loop for one tile.
|
| 1434 |
-
|
| 1435 |
-
Wraps ``StableDiffusionXLDenoiseStep`` (itself a
|
| 1436 |
-
``LoopSequentialPipelineBlocks`` over timesteps). Stored as an attribute,
|
| 1437 |
-
not in ``sub_blocks``, so this block remains a leaf.
|
| 1438 |
-
"""
|
| 1439 |
-
|
| 1440 |
-
model_name = "stable-diffusion-xl"
|
| 1441 |
-
|
| 1442 |
-
def __init__(self):
|
| 1443 |
-
super().__init__()
|
| 1444 |
-
self._denoise = StableDiffusionXLDenoiseStep()
|
| 1445 |
-
self._controlnet_denoise = StableDiffusionXLControlNetDenoiseStep()
|
| 1446 |
-
|
| 1447 |
-
@property
|
| 1448 |
-
def description(self) -> str:
|
| 1449 |
-
return (
|
| 1450 |
-
"Loop sub-block: runs the SDXL denoising loop for one tile, "
|
| 1451 |
-
"with optional ControlNet conditioning."
|
| 1452 |
-
)
|
| 1453 |
-
|
| 1454 |
-
@property
|
| 1455 |
-
def expected_components(self) -> list[ComponentSpec]:
|
| 1456 |
-
return [
|
| 1457 |
-
ComponentSpec("unet", UNet2DConditionModel),
|
| 1458 |
-
ComponentSpec("scheduler", EulerDiscreteScheduler),
|
| 1459 |
-
ComponentSpec("controlnet", ControlNetModel),
|
| 1460 |
-
ComponentSpec(
|
| 1461 |
-
"guider",
|
| 1462 |
-
ClassifierFreeGuidance,
|
| 1463 |
-
config=FrozenDict({"guidance_scale": 7.5}),
|
| 1464 |
-
default_creation_method="from_config",
|
| 1465 |
-
),
|
| 1466 |
-
]
|
| 1467 |
-
|
| 1468 |
-
@property
|
| 1469 |
-
def inputs(self) -> list[InputParam]:
|
| 1470 |
-
return [
|
| 1471 |
-
InputParam("latents", type_hint=torch.Tensor, required=True),
|
| 1472 |
-
InputParam("timesteps", type_hint=torch.Tensor, required=True),
|
| 1473 |
-
InputParam("num_inference_steps", type_hint=int, required=True),
|
| 1474 |
-
# Denoiser input fields (kwargs_type must match text encoder outputs)
|
| 1475 |
-
InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, kwargs_type="denoiser_input_fields"),
|
| 1476 |
-
InputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
|
| 1477 |
-
InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, kwargs_type="denoiser_input_fields"),
|
| 1478 |
-
InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
|
| 1479 |
-
InputParam("add_time_ids", type_hint=torch.Tensor, required=True, kwargs_type="denoiser_input_fields"),
|
| 1480 |
-
InputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
|
| 1481 |
-
InputParam("timestep_cond", type_hint=torch.Tensor),
|
| 1482 |
-
InputParam("eta", type_hint=float, default=0.0),
|
| 1483 |
-
InputParam("generator"),
|
| 1484 |
-
InputParam("use_controlnet", type_hint=bool, default=False),
|
| 1485 |
-
InputParam("controlnet_cond", type_hint=torch.Tensor),
|
| 1486 |
-
InputParam("conditioning_scale"),
|
| 1487 |
-
InputParam("controlnet_keep", type_hint=list[float]),
|
| 1488 |
-
InputParam("guess_mode", type_hint=bool, default=False),
|
| 1489 |
-
]
|
| 1490 |
-
|
| 1491 |
-
@property
|
| 1492 |
-
def intermediate_outputs(self) -> list[OutputParam]:
|
| 1493 |
-
return [
|
| 1494 |
-
OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents."),
|
| 1495 |
-
]
|
| 1496 |
-
|
| 1497 |
-
@torch.no_grad()
|
| 1498 |
-
def __call__(self, components, block_state: BlockState, tile_idx: int, tile: TileSpec):
|
| 1499 |
-
# Build a PipelineState with all the data the SDXL denoise step needs
|
| 1500 |
-
denoiser_fields = {
|
| 1501 |
-
"prompt_embeds": block_state.prompt_embeds,
|
| 1502 |
-
"negative_prompt_embeds": getattr(block_state, "negative_prompt_embeds", None),
|
| 1503 |
-
"pooled_prompt_embeds": block_state.pooled_prompt_embeds,
|
| 1504 |
-
"negative_pooled_prompt_embeds": getattr(block_state, "negative_pooled_prompt_embeds", None),
|
| 1505 |
-
"add_time_ids": block_state.add_time_ids,
|
| 1506 |
-
"negative_add_time_ids": getattr(block_state, "negative_add_time_ids", None),
|
| 1507 |
-
}
|
| 1508 |
-
# Add optional fields
|
| 1509 |
-
ip_embeds = getattr(block_state, "ip_adapter_embeds", None)
|
| 1510 |
-
neg_ip_embeds = getattr(block_state, "negative_ip_adapter_embeds", None)
|
| 1511 |
-
if ip_embeds is not None:
|
| 1512 |
-
denoiser_fields["ip_adapter_embeds"] = ip_embeds
|
| 1513 |
-
if neg_ip_embeds is not None:
|
| 1514 |
-
denoiser_fields["negative_ip_adapter_embeds"] = neg_ip_embeds
|
| 1515 |
-
|
| 1516 |
-
kwargs_type_map = {k: "denoiser_input_fields" for k in denoiser_fields}
|
| 1517 |
-
|
| 1518 |
-
all_values = {
|
| 1519 |
-
**denoiser_fields,
|
| 1520 |
-
"latents": block_state.latents,
|
| 1521 |
-
"timesteps": block_state.timesteps,
|
| 1522 |
-
"num_inference_steps": block_state.num_inference_steps,
|
| 1523 |
-
"timestep_cond": getattr(block_state, "timestep_cond", None),
|
| 1524 |
-
"eta": getattr(block_state, "eta", 0.0),
|
| 1525 |
-
"generator": getattr(block_state, "generator", None),
|
| 1526 |
-
}
|
| 1527 |
-
use_controlnet = bool(getattr(block_state, "use_controlnet", False))
|
| 1528 |
-
if use_controlnet:
|
| 1529 |
-
all_values.update(
|
| 1530 |
-
{
|
| 1531 |
-
"controlnet_cond": block_state.controlnet_cond,
|
| 1532 |
-
"conditioning_scale": block_state.conditioning_scale,
|
| 1533 |
-
"guess_mode": getattr(block_state, "guess_mode", False),
|
| 1534 |
-
"controlnet_keep": block_state.controlnet_keep,
|
| 1535 |
-
"controlnet_kwargs": getattr(block_state, "controlnet_kwargs", {}),
|
| 1536 |
-
}
|
| 1537 |
-
)
|
| 1538 |
-
|
| 1539 |
-
denoise_state = _make_state(all_values, kwargs_type_map)
|
| 1540 |
-
if use_controlnet:
|
| 1541 |
-
components, denoise_state = self._controlnet_denoise(components, denoise_state)
|
| 1542 |
-
else:
|
| 1543 |
-
components, denoise_state = self._denoise(components, denoise_state)
|
| 1544 |
-
|
| 1545 |
-
block_state.latents = denoise_state.get("latents")
|
| 1546 |
-
return components, block_state
|
| 1547 |
-
|
| 1548 |
-
|
| 1549 |
-
# ---------------------------------------------------------------------------
|
| 1550 |
-
# Loop sub-block 3: Decode + paste into canvas
|
| 1551 |
-
# ---------------------------------------------------------------------------
|
| 1552 |
-
|
| 1553 |
-
class UltimateSDUpscaleTilePostProcessStep(ModularPipelineBlocks):
|
| 1554 |
-
"""Loop sub-block that decodes one tile and pastes the core into the canvas.
|
| 1555 |
-
|
| 1556 |
-
Supports two blending modes:
|
| 1557 |
-
- ``"none"``: Non-overlapping core paste (fastest, default).
|
| 1558 |
-
- ``"gradient"``: Gradient overlap blending for smoother tile transitions.
|
| 1559 |
-
"""
|
| 1560 |
-
|
| 1561 |
-
model_name = "stable-diffusion-xl"
|
| 1562 |
-
|
| 1563 |
-
def __init__(self):
|
| 1564 |
-
super().__init__()
|
| 1565 |
-
self._decode = StableDiffusionXLDecodeStep()
|
| 1566 |
-
|
| 1567 |
-
@property
|
| 1568 |
-
def description(self) -> str:
|
| 1569 |
-
return (
|
| 1570 |
-
"Loop sub-block: decodes latents to an image via StableDiffusionXLDecodeStep, "
|
| 1571 |
-
"then extracts the core region and pastes it into the output canvas. "
|
| 1572 |
-
"Supports 'none' and 'gradient' blending modes."
|
| 1573 |
-
)
|
| 1574 |
-
|
| 1575 |
-
@property
|
| 1576 |
-
def expected_components(self) -> list[ComponentSpec]:
|
| 1577 |
-
return [
|
| 1578 |
-
ComponentSpec("vae", AutoencoderKL),
|
| 1579 |
-
ComponentSpec(
|
| 1580 |
-
"image_processor",
|
| 1581 |
-
VaeImageProcessor,
|
| 1582 |
-
config=FrozenDict({"vae_scale_factor": 8}),
|
| 1583 |
-
default_creation_method="from_config",
|
| 1584 |
-
),
|
| 1585 |
-
]
|
| 1586 |
-
|
| 1587 |
-
@property
|
| 1588 |
-
def inputs(self) -> list[InputParam]:
|
| 1589 |
-
return [
|
| 1590 |
-
InputParam("latents", type_hint=torch.Tensor, required=True),
|
| 1591 |
-
]
|
| 1592 |
-
|
| 1593 |
-
@property
|
| 1594 |
-
def intermediate_outputs(self) -> list[OutputParam]:
|
| 1595 |
-
return [] # Canvas is modified in-place on block_state
|
| 1596 |
-
|
| 1597 |
-
@torch.no_grad()
|
| 1598 |
-
def __call__(self, components, block_state: BlockState, tile_idx: int, tile: TileSpec):
|
| 1599 |
-
decode_state = _make_state({
|
| 1600 |
-
"latents": block_state.latents,
|
| 1601 |
-
"output_type": "np",
|
| 1602 |
-
})
|
| 1603 |
-
components, decode_state = self._decode(components, decode_state)
|
| 1604 |
-
decoded_images = decode_state.get("images")
|
| 1605 |
-
|
| 1606 |
-
decoded_np = decoded_images[0] # shape: (crop_h, crop_w, 3)
|
| 1607 |
-
|
| 1608 |
-
if decoded_np.shape[0] != tile.crop_h or decoded_np.shape[1] != tile.crop_w:
|
| 1609 |
-
pil_tile = PIL.Image.fromarray((np.clip(decoded_np, 0, 1) * 255).astype(np.uint8))
|
| 1610 |
-
pil_tile = pil_tile.resize((tile.crop_w, tile.crop_h), PIL.Image.LANCZOS)
|
| 1611 |
-
decoded_np = np.array(pil_tile).astype(np.float32) / 255.0
|
| 1612 |
-
|
| 1613 |
-
core = extract_core_from_decoded(decoded_np, tile)
|
| 1614 |
-
|
| 1615 |
-
blend_mode = getattr(block_state, "blend_mode", "none")
|
| 1616 |
-
if blend_mode == "gradient":
|
| 1617 |
-
overlap = getattr(block_state, "gradient_blend_overlap", 0)
|
| 1618 |
-
paste_core_into_canvas_blended(
|
| 1619 |
-
block_state.canvas, block_state.weight_map, core, tile, overlap
|
| 1620 |
-
)
|
| 1621 |
-
elif blend_mode == "none":
|
| 1622 |
-
paste_core_into_canvas(block_state.canvas, core, tile)
|
| 1623 |
-
else:
|
| 1624 |
-
raise ValueError(
|
| 1625 |
-
f"Unsupported blend_mode '{blend_mode}'. "
|
| 1626 |
-
"Supported modes: 'none', 'gradient'."
|
| 1627 |
-
)
|
| 1628 |
-
|
| 1629 |
-
return components, block_state
|
| 1630 |
-
|
| 1631 |
-
|
| 1632 |
-
# ---------------------------------------------------------------------------
|
| 1633 |
-
# Tile loop wrapper (LoopSequentialPipelineBlocks)
|
| 1634 |
-
# ---------------------------------------------------------------------------
|
| 1635 |
-
|
| 1636 |
-
class UltimateSDUpscaleTileLoopStep(LoopSequentialPipelineBlocks):
|
| 1637 |
-
"""Tile loop that iterates over the tile plan, running sub-blocks per tile.
|
| 1638 |
-
|
| 1639 |
-
Supports:
|
| 1640 |
-
- Two blending modes: ``"none"`` (core paste) and ``"gradient"`` (overlap blending)
|
| 1641 |
-
- Optional seam-fix pass: re-denoises narrow bands along tile boundaries
|
| 1642 |
-
with feathered mask blending
|
| 1643 |
-
|
| 1644 |
-
Sub-blocks:
|
| 1645 |
-
- ``UltimateSDUpscaleTilePrepareStep`` – crop, encode, prepare
|
| 1646 |
-
- ``UltimateSDUpscaleTileDenoiserStep`` – denoising loop
|
| 1647 |
-
- ``UltimateSDUpscaleTilePostProcessStep`` – decode + paste
|
| 1648 |
-
"""
|
| 1649 |
-
|
| 1650 |
-
model_name = "stable-diffusion-xl"
|
| 1651 |
-
|
| 1652 |
-
block_classes = [
|
| 1653 |
-
UltimateSDUpscaleTilePrepareStep,
|
| 1654 |
-
UltimateSDUpscaleTileDenoiserStep,
|
| 1655 |
-
UltimateSDUpscaleTilePostProcessStep,
|
| 1656 |
-
]
|
| 1657 |
-
block_names = ["tile_prepare", "tile_denoise", "tile_postprocess"]
|
| 1658 |
-
|
| 1659 |
-
@property
|
| 1660 |
-
def description(self) -> str:
|
| 1661 |
-
return (
|
| 1662 |
-
"Tile loop that iterates over the tile plan and runs sub-blocks per tile.\n"
|
| 1663 |
-
"Supports 'none' and 'gradient' blending modes, plus optional seam-fix pass.\n"
|
| 1664 |
-
"Sub-blocks:\n"
|
| 1665 |
-
" - UltimateSDUpscaleTilePrepareStep: crop, VAE encode, set timesteps, "
|
| 1666 |
-
"prepare latents, tile-aware add_cond\n"
|
| 1667 |
-
" - UltimateSDUpscaleTileDenoiserStep: SDXL denoising loop\n"
|
| 1668 |
-
" - UltimateSDUpscaleTilePostProcessStep: decode + paste core into canvas"
|
| 1669 |
-
)
|
| 1670 |
-
|
| 1671 |
-
@property
|
| 1672 |
-
def loop_inputs(self) -> list[InputParam]:
|
| 1673 |
-
return [
|
| 1674 |
-
InputParam("tile_plan", type_hint=list, required=True,
|
| 1675 |
-
description="List of TileSpec from the tile planning step."),
|
| 1676 |
-
InputParam("upscaled_image", type_hint=PIL.Image.Image, required=True),
|
| 1677 |
-
InputParam("upscaled_height", type_hint=int, required=True),
|
| 1678 |
-
InputParam("upscaled_width", type_hint=int, required=True),
|
| 1679 |
-
InputParam("tile_padding", type_hint=int, default=32),
|
| 1680 |
-
InputParam("output_type", type_hint=str, default="pil"),
|
| 1681 |
-
InputParam("blend_mode", type_hint=str, default="none",
|
| 1682 |
-
description="Blending mode: 'none' (core paste) or 'gradient' (overlap blending)."),
|
| 1683 |
-
InputParam("gradient_blend_overlap", type_hint=int, default=16,
|
| 1684 |
-
description="Width of gradient ramp in pixels for 'gradient' blend mode."),
|
| 1685 |
-
InputParam("seam_fix_plan", type_hint=list, default=[],
|
| 1686 |
-
description="List of SeamFixSpec from tile planning. Empty disables seam fix."),
|
| 1687 |
-
InputParam("seam_fix_mask_blur", type_hint=int, default=8,
|
| 1688 |
-
description="Feathering width for seam-fix band blending."),
|
| 1689 |
-
InputParam("seam_fix_strength", type_hint=float, default=0.3,
|
| 1690 |
-
description="Denoise strength for seam-fix bands."),
|
| 1691 |
-
InputParam("control_image",
|
| 1692 |
-
description="Optional ControlNet conditioning image. If provided, tile denoising uses ControlNet."),
|
| 1693 |
-
InputParam("control_guidance_start", default=0.0),
|
| 1694 |
-
InputParam("control_guidance_end", default=1.0),
|
| 1695 |
-
InputParam("controlnet_conditioning_scale", default=1.0),
|
| 1696 |
-
InputParam("guess_mode", default=False),
|
| 1697 |
-
InputParam("guidance_scale", type_hint=float, default=7.5,
|
| 1698 |
-
description="Classifier-Free Guidance scale. Higher values produce images more aligned "
|
| 1699 |
-
"with the prompt at the expense of lower image quality."),
|
| 1700 |
-
]
|
| 1701 |
-
|
| 1702 |
-
@property
|
| 1703 |
-
def loop_intermediate_outputs(self) -> list[OutputParam]:
|
| 1704 |
-
return [
|
| 1705 |
-
OutputParam("images", type_hint=list, description="Final stitched output images."),
|
| 1706 |
-
]
|
| 1707 |
-
|
| 1708 |
-
def _run_seam_fix_band(self, components, block_state, band: SeamFixSpec, band_idx: int):
|
| 1709 |
-
"""Re-denoise one seam-fix band and blend it into the canvas."""
|
| 1710 |
-
# Crop the band region directly from the float canvas to avoid
|
| 1711 |
-
# full-canvas uint8 quantization per band (quality + perf).
|
| 1712 |
-
crop_region = np.clip(
|
| 1713 |
-
block_state.canvas[band.crop_y:band.crop_y + band.crop_h,
|
| 1714 |
-
band.crop_x:band.crop_x + band.crop_w],
|
| 1715 |
-
0, 1,
|
| 1716 |
-
)
|
| 1717 |
-
crop_uint8 = (crop_region * 255).astype(np.uint8)
|
| 1718 |
-
band_crop_pil = PIL.Image.fromarray(crop_uint8)
|
| 1719 |
-
|
| 1720 |
-
# The PIL image is the crop region only, so the tile spec must use
|
| 1721 |
-
# 0-based coordinates (the entire image IS the crop).
|
| 1722 |
-
band_tile = TileSpec(
|
| 1723 |
-
core_x=band.paste_x, core_y=band.paste_y,
|
| 1724 |
-
core_w=band.band_w, core_h=band.band_h,
|
| 1725 |
-
crop_x=0, crop_y=0,
|
| 1726 |
-
crop_w=band.crop_w, crop_h=band.crop_h,
|
| 1727 |
-
paste_x=band.paste_x, paste_y=band.paste_y,
|
| 1728 |
-
)
|
| 1729 |
|
| 1730 |
-
|
| 1731 |
-
|
| 1732 |
-
|
| 1733 |
-
|
| 1734 |
-
|
| 1735 |
-
|
| 1736 |
-
|
| 1737 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1738 |
|
| 1739 |
-
|
| 1740 |
-
original_strength = block_state.strength
|
| 1741 |
-
block_state.strength = getattr(block_state, "seam_fix_strength", 0.3)
|
| 1742 |
|
| 1743 |
-
|
| 1744 |
-
|
| 1745 |
-
|
| 1746 |
|
| 1747 |
-
|
| 1748 |
-
|
| 1749 |
|
| 1750 |
-
|
| 1751 |
-
|
| 1752 |
-
|
| 1753 |
-
|
| 1754 |
-
})
|
| 1755 |
-
decode_block = self.sub_blocks["tile_postprocess"]._decode
|
| 1756 |
-
components, decode_state = decode_block(components, decode_state)
|
| 1757 |
-
decoded_np = decode_state.get("images")[0]
|
| 1758 |
|
| 1759 |
-
|
| 1760 |
-
|
| 1761 |
-
|
| 1762 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1763 |
|
| 1764 |
-
# Extract and paste band with feathered mask
|
| 1765 |
-
band_pixels = extract_band_from_decoded(decoded_np, band)
|
| 1766 |
-
seam_fix_mask_blur = getattr(block_state, "seam_fix_mask_blur", 8)
|
| 1767 |
-
paste_seam_fix_band(block_state.canvas, band_pixels, band, seam_fix_mask_blur)
|
| 1768 |
|
| 1769 |
-
|
| 1770 |
-
block_state.upscaled_image = original_image
|
| 1771 |
-
if getattr(block_state, "use_controlnet", False):
|
| 1772 |
-
block_state.control_image_processed = original_control_image
|
| 1773 |
-
block_state.strength = original_strength
|
| 1774 |
|
| 1775 |
-
return components, block_state
|
| 1776 |
|
| 1777 |
-
|
| 1778 |
-
|
| 1779 |
-
|
| 1780 |
|
| 1781 |
-
|
| 1782 |
-
|
| 1783 |
-
|
| 1784 |
-
|
| 1785 |
-
|
| 1786 |
-
|
| 1787 |
-
|
| 1788 |
-
f"Unsupported blend_mode '{blend_mode}'. Supported: 'none', 'gradient'."
|
| 1789 |
-
)
|
| 1790 |
|
| 1791 |
-
# --- Configure guidance_scale on guider ---
|
| 1792 |
-
guidance_scale = getattr(block_state, "guidance_scale", 7.5)
|
| 1793 |
-
components.guider.guidance_scale = guidance_scale
|
| 1794 |
|
| 1795 |
-
|
| 1796 |
-
|
| 1797 |
-
|
| 1798 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1799 |
raise ValueError(
|
| 1800 |
-
"
|
| 1801 |
)
|
| 1802 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1803 |
raise ValueError(
|
| 1804 |
-
"`control_image`
|
| 1805 |
-
"Load a ControlNet model (for example, a tile model) into `pipe.controlnet`."
|
| 1806 |
)
|
| 1807 |
-
|
| 1808 |
-
|
| 1809 |
-
|
| 1810 |
-
|
| 1811 |
-
|
| 1812 |
-
|
| 1813 |
-
|
| 1814 |
-
|
| 1815 |
-
|
| 1816 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1817 |
|
| 1818 |
-
|
| 1819 |
-
|
|
|
|
| 1820 |
|
| 1821 |
-
# Prepare one global latent noise tensor and crop from it per tile.
|
| 1822 |
-
# This keeps stochasticity consistent across tile boundaries.
|
| 1823 |
-
vae_scale_factor = int(getattr(components, "vae_scale_factor", 8))
|
| 1824 |
-
latent_h = max(1, h // vae_scale_factor)
|
| 1825 |
-
latent_w = max(1, w // vae_scale_factor)
|
| 1826 |
-
effective_batch = block_state.batch_size * block_state.num_images_per_prompt
|
| 1827 |
-
block_state.global_noise_map = randn_tensor(
|
| 1828 |
-
(effective_batch, 4, latent_h, latent_w),
|
| 1829 |
-
generator=getattr(block_state, "generator", None),
|
| 1830 |
-
device=components._execution_device,
|
| 1831 |
-
dtype=block_state.dtype,
|
| 1832 |
-
)
|
| 1833 |
-
block_state.global_noise_scale = vae_scale_factor
|
| 1834 |
|
| 1835 |
-
|
| 1836 |
-
|
| 1837 |
-
|
| 1838 |
-
block_state.gradient_blend_overlap = getattr(block_state, "gradient_blend_overlap", 16)
|
| 1839 |
|
| 1840 |
-
|
| 1841 |
-
|
| 1842 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1843 |
|
| 1844 |
-
logger.info(
|
| 1845 |
-
f"Processing {num_tiles} tiles"
|
| 1846 |
-
+ (f" (blend_mode={blend_mode})" if blend_mode != "none" else "")
|
| 1847 |
-
+ (f" + {len(seam_fix_plan)} seam-fix bands" if seam_fix_plan else "")
|
| 1848 |
-
)
|
| 1849 |
|
| 1850 |
-
|
| 1851 |
-
|
| 1852 |
-
for i, tile in enumerate(tile_plan):
|
| 1853 |
-
logger.debug(
|
| 1854 |
-
f"Tile {i + 1}/{num_tiles}: core=({tile.core_x},{tile.core_y},{tile.core_w},{tile.core_h}) "
|
| 1855 |
-
f"crop=({tile.crop_x},{tile.crop_y},{tile.crop_w},{tile.crop_h})"
|
| 1856 |
-
)
|
| 1857 |
-
components, block_state = self.loop_step(components, block_state, tile_idx=i, tile=tile)
|
| 1858 |
-
progress_bar.update()
|
| 1859 |
-
|
| 1860 |
-
# Finalize gradient blending before seam fix
|
| 1861 |
-
if blend_mode == "gradient":
|
| 1862 |
-
block_state.canvas = finalize_blended_canvas(block_state.canvas, block_state.weight_map)
|
| 1863 |
-
|
| 1864 |
-
# Seam-fix pass
|
| 1865 |
-
for j, band in enumerate(seam_fix_plan):
|
| 1866 |
-
logger.debug(
|
| 1867 |
-
f"Seam-fix {j + 1}/{len(seam_fix_plan)}: "
|
| 1868 |
-
f"band=({band.band_x},{band.band_y},{band.band_w},{band.band_h}) "
|
| 1869 |
-
f"{band.orientation}"
|
| 1870 |
-
)
|
| 1871 |
-
components, block_state = self._run_seam_fix_band(components, block_state, band, j)
|
| 1872 |
-
progress_bar.update()
|
| 1873 |
|
| 1874 |
-
|
| 1875 |
-
|
| 1876 |
-
|
|
|
|
| 1877 |
|
| 1878 |
-
|
| 1879 |
-
|
| 1880 |
-
|
| 1881 |
-
|
| 1882 |
-
elif output_type == "pt":
|
| 1883 |
-
block_state.images = [torch.from_numpy(result).permute(2, 0, 1).unsqueeze(0)]
|
| 1884 |
-
else:
|
| 1885 |
-
block_state.images = [PIL.Image.fromarray(result_uint8)]
|
| 1886 |
|
| 1887 |
-
|
| 1888 |
-
|
|
|
|
| 1889 |
|
|
|
|
| 1890 |
|
| 1891 |
-
|
| 1892 |
-
|
| 1893 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1894 |
|
| 1895 |
|
| 1896 |
-
|
| 1897 |
-
|
| 1898 |
-
|
| 1899 |
-
is_left: bool = False, is_right: bool = False,
|
| 1900 |
-
) -> torch.Tensor:
|
| 1901 |
-
"""Create a boundary-aware 2D cosine-ramp weight for MultiDiffusion blending.
|
| 1902 |
-
|
| 1903 |
-
Weight is 1.0 in the center and smoothly fades at edges that overlap with
|
| 1904 |
-
neighboring tiles. Edges that touch the image boundary keep weight=1.0 to
|
| 1905 |
-
prevent noise amplification from dividing by near-zero weights.
|
| 1906 |
-
|
| 1907 |
-
Args:
|
| 1908 |
-
h: Tile height in latent pixels.
|
| 1909 |
-
w: Tile width in latent pixels.
|
| 1910 |
-
overlap: Overlap in latent pixels.
|
| 1911 |
-
device: Torch device.
|
| 1912 |
-
dtype: Torch dtype.
|
| 1913 |
-
is_top: True if this tile touches the top image boundary.
|
| 1914 |
-
is_bottom: True if this tile touches the bottom image boundary.
|
| 1915 |
-
is_left: True if this tile touches the left image boundary.
|
| 1916 |
-
is_right: True if this tile touches the right image boundary.
|
| 1917 |
-
|
| 1918 |
-
Returns:
|
| 1919 |
-
Tensor of shape ``(1, 1, h, w)`` for broadcasting.
|
| 1920 |
-
"""
|
| 1921 |
-
def _ramp(length, overlap_size, keep_start, keep_end):
|
| 1922 |
-
ramp = torch.ones(length, device=device, dtype=dtype)
|
| 1923 |
-
if overlap_size > 0 and length > 2 * overlap_size:
|
| 1924 |
-
fade = 0.5 * (1.0 - torch.cos(torch.linspace(0, math.pi, overlap_size, device=device, dtype=dtype)))
|
| 1925 |
-
if not keep_start:
|
| 1926 |
-
ramp[:overlap_size] = fade
|
| 1927 |
-
if not keep_end:
|
| 1928 |
-
ramp[-overlap_size:] = fade.flip(0)
|
| 1929 |
-
return ramp
|
| 1930 |
|
| 1931 |
-
|
| 1932 |
-
|
| 1933 |
-
return (w_h[:, None] * w_w[None, :]).unsqueeze(0).unsqueeze(0)
|
| 1934 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1935 |
|
| 1936 |
class UltimateSDUpscaleMultiDiffusionStep(ModularPipelineBlocks):
|
| 1937 |
"""Single block that encodes, denoises with MultiDiffusion, and decodes.
|
|
@@ -2580,42 +1102,7 @@ class UltimateSDUpscaleMultiDiffusionStep(ModularPipelineBlocks):
|
|
| 2580 |
# modular_blocks
|
| 2581 |
# ============================================================
|
| 2582 |
|
| 2583 |
-
|
| 2584 |
-
#
|
| 2585 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 2586 |
-
# you may not use this file except in compliance with the License.
|
| 2587 |
-
# You may obtain a copy of the License at
|
| 2588 |
-
#
|
| 2589 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 2590 |
-
#
|
| 2591 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 2592 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 2593 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 2594 |
-
# See the License for the specific language governing permissions and
|
| 2595 |
-
# limitations under the License.
|
| 2596 |
-
|
| 2597 |
-
"""Top-level block composition for Modular SDXL Upscale.
|
| 2598 |
-
|
| 2599 |
-
The pipeline preserves the standard SDXL block graph as closely as
|
| 2600 |
-
possible, inserting upscale and tile-plan steps and wrapping the per-tile
|
| 2601 |
-
work in a ``LoopSequentialPipelineBlocks``::
|
| 2602 |
-
|
| 2603 |
-
text_encoder → upscale → tile_plan → input → set_timesteps → tiled_img2img
|
| 2604 |
-
|
| 2605 |
-
Inside ``tiled_img2img`` (tile loop), each tile runs:
|
| 2606 |
-
|
| 2607 |
-
tile_prepare → tile_denoise → tile_postprocess
|
| 2608 |
-
|
| 2609 |
-
Followed by an optional seam-fix pass that re-denoises narrow bands along
|
| 2610 |
-
tile boundaries with feathered mask blending.
|
| 2611 |
-
|
| 2612 |
-
Features:
|
| 2613 |
-
- Linear and chess (checkerboard) tile traversal
|
| 2614 |
-
- Non-overlapping core paste or gradient overlap blending
|
| 2615 |
-
- Optional seam-fix band re-denoise with configurable width and mask blur
|
| 2616 |
-
- Optional ControlNet tile conditioning for stronger cross-tile structure consistency
|
| 2617 |
-
- Tile-aware SDXL micro-conditioning (crops_coords_top_left per tile)
|
| 2618 |
-
"""
|
| 2619 |
|
| 2620 |
from diffusers.utils import logging
|
| 2621 |
from diffusers.modular_pipelines.modular_pipeline import SequentialPipelineBlocks
|
|
@@ -2629,85 +1116,22 @@ from diffusers.modular_pipelines.stable_diffusion_xl.before_denoise import (
|
|
| 2629 |
logger = logging.get_logger(__name__)
|
| 2630 |
|
| 2631 |
|
| 2632 |
-
class UltimateSDUpscaleBlocks(SequentialPipelineBlocks):
|
| 2633 |
-
"""Modular pipeline blocks for tiled SDXL upscaling.
|
| 2634 |
-
|
| 2635 |
-
Block graph::
|
| 2636 |
-
|
| 2637 |
-
[0] text_encoder – StableDiffusionXLTextEncoderStep (reused)
|
| 2638 |
-
[1] upscale – UltimateSDUpscaleUpscaleStep (new)
|
| 2639 |
-
[2] tile_plan – UltimateSDUpscaleTilePlanStep (new)
|
| 2640 |
-
[3] input – StableDiffusionXLInputStep (reused)
|
| 2641 |
-
[4] set_timesteps – StableDiffusionXLImg2ImgSetTimestepsStep (reused)
|
| 2642 |
-
[5] tiled_img2img – UltimateSDUpscaleTileLoopStep (tile loop + seam fix)
|
| 2643 |
-
|
| 2644 |
-
Features:
|
| 2645 |
-
- Linear and chess (checkerboard) tile traversal
|
| 2646 |
-
- Non-overlapping core paste or gradient overlap blending
|
| 2647 |
-
- Seam-fix band re-denoise with feathered mask blending
|
| 2648 |
-
- Tile-aware SDXL conditioning (crops_coords_top_left per tile)
|
| 2649 |
-
"""
|
| 2650 |
-
|
| 2651 |
-
block_classes = [
|
| 2652 |
-
UltimateSDUpscaleTextEncoderStep,
|
| 2653 |
-
UltimateSDUpscaleUpscaleStep,
|
| 2654 |
-
UltimateSDUpscaleTilePlanStep,
|
| 2655 |
-
StableDiffusionXLInputStep,
|
| 2656 |
-
StableDiffusionXLImg2ImgSetTimestepsStep,
|
| 2657 |
-
UltimateSDUpscaleTileLoopStep,
|
| 2658 |
-
]
|
| 2659 |
-
block_names = [
|
| 2660 |
-
"text_encoder",
|
| 2661 |
-
"upscale",
|
| 2662 |
-
"tile_plan",
|
| 2663 |
-
"input",
|
| 2664 |
-
"set_timesteps",
|
| 2665 |
-
"tiled_img2img",
|
| 2666 |
-
]
|
| 2667 |
-
|
| 2668 |
-
_workflow_map = {
|
| 2669 |
-
"upscale": {"image": True, "prompt": True},
|
| 2670 |
-
"upscale_controlnet": {"image": True, "control_image": True, "prompt": True},
|
| 2671 |
-
}
|
| 2672 |
-
|
| 2673 |
-
@property
|
| 2674 |
-
def description(self):
|
| 2675 |
-
return (
|
| 2676 |
-
"Modular tiled upscaling pipeline for Stable Diffusion XL.\n"
|
| 2677 |
-
"Upscales an input image and refines it using tiled denoising.\n"
|
| 2678 |
-
"Default: single-pass mode (tile_size=2048) — seamless, no tile artifacts.\n"
|
| 2679 |
-
"For very large images: set tile_size=512 for tiled mode with optional "
|
| 2680 |
-
"chess traversal, gradient blending, seam-fix, and ControlNet tile conditioning."
|
| 2681 |
-
)
|
| 2682 |
-
|
| 2683 |
-
@property
|
| 2684 |
-
def outputs(self):
|
| 2685 |
-
return [OutputParam.template("images")]
|
| 2686 |
-
|
| 2687 |
-
|
| 2688 |
class MultiDiffusionUpscaleBlocks(SequentialPipelineBlocks):
|
| 2689 |
"""Modular pipeline blocks for tiled SDXL upscaling with MultiDiffusion.
|
| 2690 |
|
| 2691 |
Uses latent-space noise prediction blending across overlapping tiles for
|
| 2692 |
-
|
| 2693 |
-
block set for high-quality upscaling.
|
| 2694 |
|
| 2695 |
Block graph::
|
| 2696 |
|
| 2697 |
-
[0] text_encoder
|
| 2698 |
-
[1] upscale
|
| 2699 |
-
[2] input
|
| 2700 |
-
[3] set_timesteps
|
| 2701 |
-
[4] multidiffusion
|
| 2702 |
|
| 2703 |
The MultiDiffusion step handles VAE encode, tiled denoise with blending,
|
| 2704 |
and VAE decode internally, using VAE tiling for memory efficiency.
|
| 2705 |
-
|
| 2706 |
-
Features:
|
| 2707 |
-
- Seamless output at any resolution (no tile boundary artifacts)
|
| 2708 |
-
- Optional ControlNet Tile conditioning
|
| 2709 |
-
- Configurable latent tile size and overlap
|
| 2710 |
-
- Single-pass for small images, tiled for large images
|
| 2711 |
"""
|
| 2712 |
|
| 2713 |
block_classes = [
|
|
|
|
| 8 |
# utils_tiling
|
| 9 |
# ============================================================
|
| 10 |
|
| 11 |
+
"""Tile planning and cosine blending weights for MultiDiffusion."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
+
from dataclasses import dataclass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
+
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
@dataclass
|
| 19 |
class LatentTileSpec:
|
| 20 |
+
"""Tile specification in latent space.
|
| 21 |
|
| 22 |
Attributes:
|
| 23 |
y: Top edge in latent pixels.
|
|
|
|
| 32 |
w: int
|
| 33 |
|
| 34 |
|
| 35 |
+
def validate_tile_params(tile_size: int, overlap: int) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
if tile_size <= 0:
|
| 37 |
raise ValueError(f"`tile_size` must be positive, got {tile_size}.")
|
| 38 |
if overlap < 0:
|
|
|
|
| 43 |
f"Got overlap={overlap}, tile_size={tile_size}."
|
| 44 |
)
|
| 45 |
|
| 46 |
+
|
| 47 |
+
def plan_latent_tiles(
|
| 48 |
+
latent_h: int,
|
| 49 |
+
latent_w: int,
|
| 50 |
+
tile_size: int = 64,
|
| 51 |
+
overlap: int = 8,
|
| 52 |
+
) -> list[LatentTileSpec]:
|
| 53 |
+
"""Plan overlapping tiles in latent space for MultiDiffusion.
|
| 54 |
+
|
| 55 |
+
Tiles overlap by ``overlap`` latent pixels. Edge tiles are clamped to
|
| 56 |
+
the latent bounds.
|
| 57 |
+
"""
|
| 58 |
+
validate_tile_params(tile_size, overlap)
|
| 59 |
+
|
| 60 |
stride = tile_size - overlap
|
| 61 |
tiles: list[LatentTileSpec] = []
|
| 62 |
|
| 63 |
y = 0
|
| 64 |
while y < latent_h:
|
| 65 |
h = min(tile_size, latent_h - y)
|
|
|
|
| 66 |
if h < tile_size and y > 0:
|
| 67 |
y = max(0, latent_h - tile_size)
|
| 68 |
h = latent_h - y
|
|
|
|
| 87 |
return tiles
|
| 88 |
|
| 89 |
|
| 90 |
+
def make_cosine_tile_weight(
|
| 91 |
+
h: int,
|
| 92 |
+
w: int,
|
| 93 |
+
overlap: int,
|
| 94 |
+
device: torch.device,
|
| 95 |
+
dtype: torch.dtype,
|
| 96 |
+
is_top: bool = False,
|
| 97 |
+
is_bottom: bool = False,
|
| 98 |
+
is_left: bool = False,
|
| 99 |
+
is_right: bool = False,
|
| 100 |
+
) -> torch.Tensor:
|
| 101 |
+
"""Boundary-aware cosine blending weight for one tile.
|
| 102 |
+
|
| 103 |
+
Returns shape (1, 1, h, w). Canvas-edge sides get weight 1.0 (no fade),
|
| 104 |
+
interior overlap regions get a half-cosine ramp from 0 to 1.
|
| 105 |
+
"""
|
| 106 |
+
import math
|
| 107 |
+
|
| 108 |
+
wy = torch.ones(h, device=device, dtype=dtype)
|
| 109 |
+
wx = torch.ones(w, device=device, dtype=dtype)
|
| 110 |
+
|
| 111 |
+
ramp = min(overlap, h // 2, w // 2)
|
| 112 |
+
if ramp <= 0:
|
| 113 |
+
return torch.ones(1, 1, h, w, device=device, dtype=dtype)
|
| 114 |
+
|
| 115 |
+
cos_ramp = torch.tensor(
|
| 116 |
+
[0.5 * (1 - math.cos(math.pi * i / ramp)) for i in range(ramp)],
|
| 117 |
+
device=device,
|
| 118 |
+
dtype=dtype,
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
if not is_top:
|
| 122 |
+
wy[:ramp] = cos_ramp
|
| 123 |
+
if not is_bottom:
|
| 124 |
+
wy[-ramp:] = cos_ramp.flip(0)
|
| 125 |
+
if not is_left:
|
| 126 |
+
wx[:ramp] = cos_ramp
|
| 127 |
+
if not is_right:
|
| 128 |
+
wx[-ramp:] = cos_ramp.flip(0)
|
| 129 |
+
|
| 130 |
+
weight = wy[:, None] * wx[None, :]
|
| 131 |
+
return weight.unsqueeze(0).unsqueeze(0)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
# ============================================================
|
| 135 |
# input
|
| 136 |
# ============================================================
|
|
|
|
| 149 |
# See the License for the specific language governing permissions and
|
| 150 |
# limitations under the License.
|
| 151 |
|
| 152 |
+
"""Input steps for Modular SDXL Upscale: text encoding, Lanczos upscale."""
|
| 153 |
+
|
| 154 |
import PIL.Image
|
| 155 |
import torch
|
| 156 |
|
|
|
|
| 166 |
class UltimateSDUpscaleTextEncoderStep(StableDiffusionXLTextEncoderStep):
|
| 167 |
"""SDXL text encoder step that applies guidance scale before encoding.
|
| 168 |
|
| 169 |
+
Syncs the guider's guidance_scale before prompt encoding so that
|
| 170 |
+
unconditional embeddings are always produced when CFG is active.
|
|
|
|
| 171 |
|
| 172 |
+
Also applies a default negative prompt for upscaling when the user
|
| 173 |
+
does not provide one.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
"""
|
| 175 |
|
| 176 |
DEFAULT_NEGATIVE_PROMPT = "blurry, low quality, artifacts, noise, jpeg compression"
|
| 177 |
|
| 178 |
@property
|
| 179 |
def inputs(self) -> list[InputParam]:
|
|
|
|
| 180 |
return super().inputs + [
|
| 181 |
InputParam(
|
| 182 |
"guidance_scale",
|
| 183 |
type_hint=float,
|
| 184 |
default=7.5,
|
| 185 |
+
description="Classifier-Free Guidance scale.",
|
|
|
|
|
|
|
|
|
|
| 186 |
),
|
| 187 |
InputParam(
|
| 188 |
"use_default_negative",
|
| 189 |
type_hint=bool,
|
| 190 |
default=True,
|
| 191 |
+
description="Apply default negative prompt when none is provided.",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
),
|
| 193 |
]
|
| 194 |
|
|
|
|
| 200 |
if hasattr(components, "guider") and components.guider is not None:
|
| 201 |
components.guider.guidance_scale = guidance_scale
|
| 202 |
|
|
|
|
| 203 |
use_default_negative = getattr(block_state, "use_default_negative", True)
|
| 204 |
if use_default_negative:
|
| 205 |
neg = getattr(block_state, "negative_prompt", None)
|
|
|
|
| 211 |
|
| 212 |
|
| 213 |
class UltimateSDUpscaleUpscaleStep(ModularPipelineBlocks):
|
| 214 |
+
"""Upscales the input image using Lanczos interpolation."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
|
| 216 |
@property
|
| 217 |
def description(self) -> str:
|
| 218 |
+
return "Upscale input image using Lanczos interpolation."
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
@property
|
| 221 |
def inputs(self) -> list[InputParam]:
|
| 222 |
return [
|
| 223 |
+
InputParam("image", type_hint=PIL.Image.Image, required=True,
|
| 224 |
+
description="Input image to upscale."),
|
| 225 |
+
InputParam("upscale_factor", type_hint=float, default=2.0,
|
| 226 |
+
description="Scale multiplier."),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
]
|
| 228 |
|
| 229 |
@property
|
| 230 |
def intermediate_outputs(self) -> list[OutputParam]:
|
| 231 |
return [
|
| 232 |
+
OutputParam("upscaled_image", type_hint=PIL.Image.Image),
|
| 233 |
+
OutputParam("upscaled_width", type_hint=int),
|
| 234 |
+
OutputParam("upscaled_height", type_hint=int),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
]
|
| 236 |
|
| 237 |
@torch.no_grad()
|
|
|
|
| 242 |
upscale_factor = block_state.upscale_factor
|
| 243 |
|
| 244 |
if not isinstance(image, PIL.Image.Image):
|
| 245 |
+
raise ValueError(f"Expected PIL.Image, got {type(image)}.")
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
new_width = int(image.width * upscale_factor)
|
| 248 |
new_height = int(image.height * upscale_factor)
|
|
|
|
| 251 |
block_state.upscaled_width = new_width
|
| 252 |
block_state.upscaled_height = new_height
|
| 253 |
|
| 254 |
+
logger.info(f"Upscaled {image.width}x{image.height} -> {new_width}x{new_height}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
|
| 256 |
self.set_block_state(state, block_state)
|
| 257 |
return components, state
|
| 258 |
|
| 259 |
|
| 260 |
+
# ============================================================
|
| 261 |
+
# denoise
|
| 262 |
+
# ============================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
|
| 264 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 265 |
+
#
|
| 266 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 267 |
+
# you may not use this file except in compliance with the License.
|
| 268 |
+
# You may obtain a copy of the License at
|
| 269 |
+
#
|
| 270 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 271 |
+
#
|
| 272 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 273 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 274 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 275 |
+
# See the License for the specific language governing permissions and
|
| 276 |
+
# limitations under the License.
|
| 277 |
|
| 278 |
+
"""MultiDiffusion tiled upscaling step for Modular SDXL Upscale.
|
|
|
|
|
|
|
| 279 |
|
| 280 |
+
Blends noise predictions from overlapping latent tiles using cosine weights.
|
| 281 |
+
Reuses SDXL blocks via their public interface.
|
| 282 |
+
"""
|
| 283 |
|
| 284 |
+
import math
|
| 285 |
+
import time
|
| 286 |
|
| 287 |
+
import numpy as np
|
| 288 |
+
import PIL.Image
|
| 289 |
+
import torch
|
| 290 |
+
from tqdm.auto import tqdm
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
|
| 292 |
+
from diffusers.configuration_utils import FrozenDict
|
| 293 |
+
from diffusers.guiders import ClassifierFreeGuidance
|
| 294 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 295 |
+
from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
| 296 |
+
from diffusers.schedulers import DPMSolverMultistepScheduler, EulerDiscreteScheduler
|
| 297 |
+
from diffusers.utils import logging
|
| 298 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 299 |
+
from diffusers.modular_pipelines.modular_pipeline import (
|
| 300 |
+
ModularPipelineBlocks,
|
| 301 |
+
PipelineState,
|
| 302 |
+
)
|
| 303 |
+
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
| 304 |
+
from diffusers.modular_pipelines.stable_diffusion_xl.before_denoise import (
|
| 305 |
+
StableDiffusionXLControlNetInputStep,
|
| 306 |
+
StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
|
| 307 |
+
StableDiffusionXLImg2ImgPrepareLatentsStep,
|
| 308 |
+
prepare_latents_img2img,
|
| 309 |
+
)
|
| 310 |
+
from diffusers.modular_pipelines.stable_diffusion_xl.decoders import StableDiffusionXLDecodeStep
|
| 311 |
+
from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLVaeEncoderStep
|
| 312 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
+
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
|
|
|
|
| 316 |
|
| 317 |
+
# ---------------------------------------------------------------------------
|
| 318 |
+
# Helper: populate a PipelineState from a dict
|
| 319 |
+
# ---------------------------------------------------------------------------
|
| 320 |
|
| 321 |
+
def _make_state(values: dict, kwargs_type_map: dict | None = None) -> PipelineState:
    """Build a PipelineState pre-populated from a plain dict.

    Args:
        values: Mapping of state name -> value to set.
        kwargs_type_map: Optional mapping of state name -> kwargs_type,
            forwarded as the third argument of ``PipelineState.set``
            (``None`` for names not present in the mapping).
    """
    type_map = kwargs_type_map or {}
    state = PipelineState()
    for name, value in values.items():
        state.set(name, value, type_map.get(name))
    return state
|
|
|
|
|
|
|
| 328 |
|
|
|
|
|
|
|
|
|
|
| 329 |
|
| 330 |
+
def _to_pil_rgb_image(image) -> PIL.Image.Image:
    """Convert a tensor/ndarray/PIL image to an RGB PIL image.

    Accepts a PIL image (returned as RGB), a torch tensor, or a numpy
    array. Tensor/array inputs may be a batch of one (4-D), a CHW or HWC
    3-D image, or a 2-D grayscale image. Non-uint8 inputs are assumed to
    be in [0, 1] when their max is <= 1, otherwise in [0, 255].

    Raises:
        ValueError: For unsupported types, batch sizes > 1, or bad shapes.
    """
    if isinstance(image, PIL.Image.Image):
        return image.convert("RGB")

    if torch.is_tensor(image):
        tensor = image.detach().cpu()
        if tensor.ndim == 4:
            if tensor.shape[0] != 1:
                raise ValueError(
                    f"`control_image` tensor batch must be 1 for tiled upscaling, got shape {tuple(tensor.shape)}."
                )
            tensor = tensor[0]
        # Heuristic CHW -> HWC: leading dim looks like channels, trailing does not.
        if tensor.ndim == 3 and tensor.shape[0] in (1, 3, 4) and tensor.shape[-1] not in (1, 3, 4):
            tensor = tensor.permute(1, 2, 0)
        image = tensor.numpy()

    if not isinstance(image, np.ndarray):
        raise ValueError(
            f"Unsupported `control_image` type {type(image)}. Expected PIL.Image, torch.Tensor, or numpy.ndarray."
        )

    arr = image
    if arr.ndim == 4:
        if arr.shape[0] != 1:
            raise ValueError(
                f"`control_image` ndarray batch must be 1 for tiled upscaling, got shape {arr.shape}."
            )
        arr = arr[0]
    # Same CHW -> HWC heuristic as above, for ndarray inputs.
    if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
        arr = np.transpose(arr, (1, 2, 0))
    if arr.ndim == 2:
        # Grayscale: replicate into three identical channels.
        arr = np.stack([arr] * 3, axis=-1)
    if arr.ndim != 3:
        raise ValueError(f"`control_image` must have 2 or 3 dimensions, got shape {arr.shape}.")
    if arr.shape[-1] == 1:
        arr = np.repeat(arr, 3, axis=-1)
    if arr.shape[-1] == 4:
        # Drop the alpha channel.
        arr = arr[..., :3]
    if arr.shape[-1] != 3:
        raise ValueError(f"`control_image` channel dimension must be 1/3/4, got shape {arr.shape}.")
    if arr.dtype != np.uint8:
        arr = np.asarray(arr, dtype=np.float32)
        peak = float(np.max(arr)) if arr.size > 0 else 1.0
        # Heuristic value range: max <= 1 means [0, 1] floats, else [0, 255].
        limit, scale = (1.0, 255.0) if peak <= 1.0 else (255.0, 1.0)
        arr = (np.clip(arr, 0.0, limit) * scale).astype(np.uint8)
    return PIL.Image.fromarray(arr).convert("RGB")
|
| 379 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
|
| 381 |
+
# ---------------------------------------------------------------------------
|
| 382 |
+
# Scheduler swap helper (Feature 5)
|
| 383 |
+
# ---------------------------------------------------------------------------
|
|
|
|
| 384 |
|
| 385 |
+
# Maps case-insensitive user-facing scheduler names to canonical scheduler
# class names. A trailing "+karras" marker means DPMSolverMultistep with
# Karras sigmas enabled; the suffix is parsed off by `_swap_scheduler`.
_SCHEDULER_ALIASES = {
    "euler": "EulerDiscreteScheduler",
    "euler discrete": "EulerDiscreteScheduler",
    "eulerdiscretescheduler": "EulerDiscreteScheduler",
    "dpm++ 2m": "DPMSolverMultistepScheduler",
    "dpmsolvermultistepscheduler": "DPMSolverMultistepScheduler",
    "dpm++ 2m karras": "DPMSolverMultistepScheduler+karras",
}
|
| 393 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
|
| 395 |
+
def _swap_scheduler(components, scheduler_name: str):
    """Swap the scheduler on ``components`` given a human-readable name.

    Supported names (case-insensitive):
        - ``"Euler"`` / ``"EulerDiscreteScheduler"``
        - ``"DPM++ 2M"`` / ``"DPMSolverMultistepScheduler"``
        - ``"DPM++ 2M Karras"`` (DPMSolverMultistep with Karras sigmas)

    If the requested scheduler is already active, this is a no-op; unknown
    names keep the current scheduler and log a warning.
    """
    normalized = scheduler_name.strip().lower()
    target = _SCHEDULER_ALIASES.get(normalized, normalized)

    # A "+karras" marker in the alias table requests Karras sigmas.
    want_karras = target.endswith("+karras")
    if want_karras:
        target = target[: -len("+karras")]

    active = type(components.scheduler).__name__

    if target == "EulerDiscreteScheduler":
        if active != "EulerDiscreteScheduler":
            components.scheduler = EulerDiscreteScheduler.from_config(components.scheduler.config)
            logger.info("Swapped scheduler to EulerDiscreteScheduler")
        return

    if target == "DPMSolverMultistepScheduler":
        # Rebuild also when karras sigmas were requested but are missing.
        karras_missing = want_karras and not getattr(
            components.scheduler.config, "use_karras_sigmas", False
        )
        if active != "DPMSolverMultistepScheduler" or karras_missing:
            kwargs = {"use_karras_sigmas": True} if want_karras else {}
            components.scheduler = DPMSolverMultistepScheduler.from_config(
                components.scheduler.config, **kwargs
            )
            logger.info(f"Swapped scheduler to DPMSolverMultistepScheduler (karras={want_karras})")
        return

    logger.warning(
        f"Unknown scheduler_name '{scheduler_name}'. Keeping current scheduler "
        f"({active}). Supported: 'Euler', 'DPM++ 2M', 'DPM++ 2M Karras'."
    )
|
| 434 |
|
| 435 |
|
| 436 |
+
# ---------------------------------------------------------------------------
|
| 437 |
+
# Auto-strength helper (Feature 2)
|
| 438 |
+
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
|
| 440 |
+
def _compute_auto_strength(upscale_factor: float, pass_index: int, num_passes: int) -> float:
|
| 441 |
+
"""Return the auto-scaled denoise strength for a given pass.
|
|
|
|
| 442 |
|
| 443 |
+
Rules:
|
| 444 |
+
- Single-pass 2x: 0.3
|
| 445 |
+
- Single-pass 4x: 0.15
|
| 446 |
+
- Progressive passes: first pass=0.3, subsequent passes=0.2
|
| 447 |
+
"""
|
| 448 |
+
if num_passes > 1:
|
| 449 |
+
return 0.3 if pass_index == 0 else 0.2
|
| 450 |
+
# Single pass
|
| 451 |
+
if upscale_factor <= 2.0:
|
| 452 |
+
return 0.3
|
| 453 |
+
elif upscale_factor <= 4.0:
|
| 454 |
+
return 0.15
|
| 455 |
+
else:
|
| 456 |
+
return 0.1
|
| 457 |
|
| 458 |
class UltimateSDUpscaleMultiDiffusionStep(ModularPipelineBlocks):
|
| 459 |
"""Single block that encodes, denoises with MultiDiffusion, and decodes.
|
|
|
|
| 1102 |
# modular_blocks
|
| 1103 |
# ============================================================
|
| 1104 |
|
| 1105 |
+
"""Block composition for Modular SDXL Upscale."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1106 |
|
| 1107 |
from diffusers.utils import logging
|
| 1108 |
from diffusers.modular_pipelines.modular_pipeline import SequentialPipelineBlocks
|
|
|
|
| 1116 |
logger = logging.get_logger(__name__)
|
| 1117 |
|
| 1118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1119 |
class MultiDiffusionUpscaleBlocks(SequentialPipelineBlocks):
|
| 1120 |
"""Modular pipeline blocks for tiled SDXL upscaling with MultiDiffusion.
|
| 1121 |
|
| 1122 |
Uses latent-space noise prediction blending across overlapping tiles for
|
| 1123 |
+
seamless tiled upscaling at any resolution.
|
|
|
|
| 1124 |
|
| 1125 |
Block graph::
|
| 1126 |
|
| 1127 |
+
[0] text_encoder - SDXL TextEncoderStep (reused)
|
| 1128 |
+
[1] upscale - Lanczos resize
|
| 1129 |
+
[2] input - SDXL InputStep (reused)
|
| 1130 |
+
[3] set_timesteps - SDXL Img2Img SetTimestepsStep (reused)
|
| 1131 |
+
[4] multidiffusion - MultiDiffusion step
|
| 1132 |
|
| 1133 |
The MultiDiffusion step handles VAE encode, tiled denoise with blending,
|
| 1134 |
and VAE decode internally, using VAE tiling for memory efficiency.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1135 |
"""
|
| 1136 |
|
| 1137 |
block_classes = [
|