akshan-main commited on
Commit
7183ccf
·
verified ·
1 Parent(s): ac159db

Upload block.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. block.py +254 -1830
block.py CHANGED
@@ -8,589 +8,16 @@ MultiDiffusion tiled upscaling for Stable Diffusion XL using Modular Diffusers.
8
  # utils_tiling
9
  # ============================================================
10
 
11
- # Copyright 2025 The HuggingFace Team. All rights reserved.
12
- #
13
- # Licensed under the Apache License, Version 2.0 (the "License");
14
- # you may not use this file except in compliance with the License.
15
- # You may obtain a copy of the License at
16
- #
17
- # http://www.apache.org/licenses/LICENSE-2.0
18
- #
19
- # Unless required by applicable law or agreed to in writing, software
20
- # distributed under the License is distributed on an "AS IS" BASIS,
21
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22
- # See the License for the specific language governing permissions and
23
- # limitations under the License.
24
-
25
- """Pure utility functions for tiled upscale workflows.
26
-
27
- Supports:
28
- - Linear (raster) and chess (checkerboard) tile traversal
29
- - Non-overlapping core paste and gradient overlap blending
30
- - Seam-fix band planning along tile boundaries
31
- - Linear feathered mask blending for seam-fix bands
32
- """
33
-
34
- from dataclasses import dataclass, field
35
-
36
- import numpy as np
37
- import PIL.Image
38
-
39
-
40
@dataclass
class TileSpec:
    """Geometry of one upscale tile.

    A tile carries two rectangles: the *core*, the region of the output
    canvas this tile is responsible for producing, and the *crop*, a padded
    rectangle around the core that is actually denoised so the core receives
    surrounding context.

    Attributes:
        core_x: Left edge of the core region in the output canvas.
        core_y: Top edge of the core region in the output canvas.
        core_w: Width of the core region.
        core_h: Height of the core region.
        crop_x: Left edge of the padded crop in the source image.
        crop_y: Top edge of the padded crop in the source image.
        crop_w: Width of the padded crop (the region that gets denoised).
        crop_h: Height of the padded crop.
        paste_x: Horizontal offset of the core inside the crop.
        paste_y: Vertical offset of the core inside the crop.
    """

    core_x: int
    core_y: int
    core_w: int
    core_h: int
    crop_x: int
    crop_y: int
    crop_w: int
    crop_h: int
    paste_x: int
    paste_y: int
68
-
69
-
70
@dataclass
class SeamFixSpec:
    """Geometry of one seam-fix band straddling a tile boundary.

    Attributes:
        band_x: Left edge of the band in the output canvas.
        band_y: Top edge of the band in the output canvas.
        band_w: Width of the band.
        band_h: Height of the band.
        crop_x: Left edge of the padded crop used for denoising.
        crop_y: Top edge of the padded crop used for denoising.
        crop_w: Width of the padded crop.
        crop_h: Height of the padded crop.
        paste_x: Horizontal offset of the band inside the crop.
        paste_y: Vertical offset of the band inside the crop.
        orientation: Either 'horizontal' or 'vertical'.
    """

    band_x: int
    band_y: int
    band_w: int
    band_h: int
    crop_x: int
    crop_y: int
    crop_w: int
    crop_h: int
    paste_x: int
    paste_y: int
    orientation: str = "horizontal"
99
-
100
-
101
def validate_tile_params(tile_size: int, tile_padding: int) -> None:
    """Strictly validate tiling parameters.

    Args:
        tile_size: Base tile size in pixels; must be positive.
        tile_padding: Overlap padding on each side; must be non-negative and
            strictly less than ``tile_size // 2`` (so the core stays non-empty).

    Raises:
        ValueError: If any parameter is out of range.
    """
    if tile_size <= 0:
        raise ValueError(f"`tile_size` must be positive, got {tile_size}.")
    if tile_padding < 0:
        raise ValueError(f"`tile_padding` must be non-negative, got {tile_padding}.")
    if tile_padding >= tile_size // 2:
        message = (
            f"`tile_padding` must be less than tile_size // 2. "
            f"Got tile_padding={tile_padding}, tile_size={tile_size} "
            f"(max allowed: {tile_size // 2 - 1})."
        )
        raise ValueError(message)
121
-
122
-
123
def plan_tiles_linear(
    image_width: int,
    image_height: int,
    tile_size: int = 512,
    tile_padding: int = 32,
) -> list[TileSpec]:
    """Plan tiles in raster (left-to-right, top-to-bottom) order.

    Each tile's core region covers ``tile_size - 2 * tile_padding`` pixels per
    axis (clamped at the image edges); its crop region extends the core by
    ``tile_padding`` on every side, clamped to the image bounds.

    Args:
        image_width: Width of the image to tile.
        image_height: Height of the image to tile.
        tile_size: Base tile size in pixels.
        tile_padding: Number of overlap pixels on each side of a tile.

    Returns:
        Tiles in raster traversal order.

    Raises:
        ValueError: If the tile parameters are invalid.
    """
    validate_tile_params(tile_size, tile_padding)

    step = tile_size - 2 * tile_padding
    plan: list[TileSpec] = []

    for top in range(0, image_height, step):
        height = min(step, image_height - top)

        for left in range(0, image_width, step):
            width = min(step, image_width - left)

            # Padded crop rectangle, clamped to the image bounds.
            cx1 = max(0, left - tile_padding)
            cy1 = max(0, top - tile_padding)
            cx2 = min(image_width, left + width + tile_padding)
            cy2 = min(image_height, top + height + tile_padding)

            plan.append(
                TileSpec(
                    core_x=left,
                    core_y=top,
                    core_w=width,
                    core_h=height,
                    crop_x=cx1,
                    crop_y=cy1,
                    crop_w=cx2 - cx1,
                    crop_h=cy2 - cy1,
                    # Where the core sits inside the crop.
                    paste_x=left - cx1,
                    paste_y=top - cy1,
                )
            )

    return plan
189
-
190
-
191
def crop_tile(image: PIL.Image.Image, tile: TileSpec) -> PIL.Image.Image:
    """Cut the padded crop region of *tile* out of *image*.

    Args:
        image: Source image.
        tile: Tile whose crop rectangle should be extracted.

    Returns:
        PIL image containing the padded crop region.
    """
    left, top = tile.crop_x, tile.crop_y
    box = (left, top, left + tile.crop_w, top + tile.crop_h)
    return image.crop(box)
202
-
203
-
204
def extract_core_from_decoded(decoded_image: np.ndarray, tile: TileSpec) -> np.ndarray:
    """Slice the core region out of a decoded tile crop.

    Args:
        decoded_image: Decoded tile pixels, shape (crop_h, crop_w, C).
        tile: Tile whose core should be extracted.

    Returns:
        Array view of shape (core_h, core_w, C).
    """
    top, left = tile.paste_y, tile.paste_x
    return decoded_image[top : top + tile.core_h, left : left + tile.core_w]
218
-
219
-
220
def paste_core_into_canvas(
    canvas: np.ndarray,
    core_image: np.ndarray,
    tile: TileSpec,
) -> None:
    """Write a tile's core pixels straight into the output canvas.

    No blending is performed: core regions partition the canvas, so each
    destination rectangle is written by exactly one tile.

    Args:
        canvas: Output canvas, shape (H, W, C), float32; modified in-place.
        core_image: Core pixels, shape (core_h, core_w, C), float32.
        tile: Tile specification giving the destination rectangle.
    """
    y, x = tile.core_y, tile.core_x
    canvas[y : y + tile.core_h, x : x + tile.core_w] = core_image
235
-
236
-
237
- # =============================================================================
238
- # Chess (checkerboard) traversal
239
- # =============================================================================
240
-
241
-
242
def plan_tiles_chess(
    image_width: int,
    image_height: int,
    tile_size: int = 512,
    tile_padding: int = 32,
) -> list[TileSpec]:
    """Plan tiles in checkerboard (chess) traversal order.

    All "white" squares (row + col even) come first, then all "black"
    squares, so neighbouring tiles are never denoised back to back and seam
    patterns are less visible.

    Args:
        image_width: Width of the image to tile.
        image_height: Height of the image to tile.
        tile_size: Base tile size in pixels.
        tile_padding: Number of overlap pixels on each side of a tile.

    Returns:
        Tiles in chess traversal order.

    Raises:
        ValueError: If the tile parameters are invalid.
    """
    validate_tile_params(tile_size, tile_padding)

    step = tile_size - 2 * tile_padding
    white: list[TileSpec] = []
    black: list[TileSpec] = []

    for row, top in enumerate(range(0, image_height, step)):
        height = min(step, image_height - top)

        for col, left in enumerate(range(0, image_width, step)):
            width = min(step, image_width - left)

            # Padded crop rectangle, clamped to the image bounds.
            cx1 = max(0, left - tile_padding)
            cy1 = max(0, top - tile_padding)
            cx2 = min(image_width, left + width + tile_padding)
            cy2 = min(image_height, top + height + tile_padding)

            spec = TileSpec(
                core_x=left,
                core_y=top,
                core_w=width,
                core_h=height,
                crop_x=cx1,
                crop_y=cy1,
                crop_w=cx2 - cx1,
                crop_h=cy2 - cy1,
                paste_x=left - cx1,
                paste_y=top - cy1,
            )
            # Bucket by checkerboard colour; each bucket keeps raster order.
            (white if (row + col) % 2 == 0 else black).append(spec)

    return white + black
307
-
308
-
309
- # =============================================================================
310
- # Gradient overlap blending
311
- # =============================================================================
312
-
313
-
314
def make_gradient_mask(
    core_h: int,
    core_w: int,
    overlap: int,
    at_top: bool = False,
    at_bottom: bool = False,
    at_left: bool = False,
    at_right: bool = False,
) -> np.ndarray:
    """Create a boundary-aware gradient blending mask for a tile's core region.

    The mask is 1.0 in the interior and ramps down linearly over ``overlap``
    pixels toward each interior edge. Edges that touch the canvas boundary
    (the ``at_*`` flags) stay at 1.0 so canvas borders are not darkened.

    Ramp values are strictly positive — in (0, 1] rather than [0, 1]. The
    previous ``linspace(0, 1, n)`` ramps reached exactly 0 at the edge pixel;
    because core regions tile the canvas without overlapping, both tiles
    adjacent to a seam then contributed zero weight to their shared boundary
    rows/columns, leaving zero-weight (black) seam lines after
    ``finalize_blended_canvas``.

    Args:
        core_h: Height of the core region.
        core_w: Width of the core region.
        overlap: Width of the gradient ramp in pixels; <= 0 yields all ones.
        at_top: True if the tile touches the top edge of the canvas.
        at_bottom: True if the tile touches the bottom edge of the canvas.
        at_left: True if the tile touches the left edge of the canvas.
        at_right: True if the tile touches the right edge of the canvas.

    Returns:
        Mask of shape (core_h, core_w), float32, values in (0, 1].
    """
    mask = np.ones((core_h, core_w), dtype=np.float32)
    if overlap <= 0:
        return mask

    def _rising(n: int) -> np.ndarray:
        # n samples climbing toward 1.0 with the 0.0 endpoint dropped, so no
        # pixel ever receives exactly zero weight.
        return np.linspace(0.0, 1.0, n + 1, dtype=np.float32)[1:]

    # Only fade on interior edges (not canvas boundaries).
    ramp_w = min(overlap, core_w)
    if ramp_w > 0:
        rising = _rising(ramp_w)
        if not at_left:
            mask[:, :ramp_w] = np.minimum(mask[:, :ramp_w], rising[np.newaxis, :])
        if not at_right:
            mask[:, -ramp_w:] = np.minimum(mask[:, -ramp_w:], rising[::-1][np.newaxis, :])

    ramp_h = min(overlap, core_h)
    if ramp_h > 0:
        rising = _rising(ramp_h)
        if not at_top:
            mask[:ramp_h, :] = np.minimum(mask[:ramp_h, :], rising[:, np.newaxis])
        if not at_bottom:
            mask[-ramp_h:, :] = np.minimum(mask[-ramp_h:, :], rising[::-1][:, np.newaxis])

    return mask
364
-
365
-
366
def paste_core_into_canvas_blended(
    canvas: np.ndarray,
    weight_map: np.ndarray,
    core_image: np.ndarray,
    tile: TileSpec,
    overlap: int,
) -> None:
    """Accumulate a tile core into the canvas with gradient blending.

    The canvas stores a weighted pixel sum and ``weight_map`` the matching
    weight totals; call ``finalize_blended_canvas`` afterwards to divide the
    one by the other.

    Args:
        canvas: Weighted-sum canvas, shape (H, W, C), float32; modified in-place.
        weight_map: Weight accumulator, shape (H, W), float32; modified in-place.
        core_image: Core pixels, shape (core_h, core_w, C), float32.
        tile: Tile specification giving the destination rectangle.
        overlap: Gradient ramp width in pixels.
    """
    height, width = canvas.shape[:2]

    # Fade only on edges that face another tile, never on canvas borders.
    blend = make_gradient_mask(
        tile.core_h,
        tile.core_w,
        overlap,
        at_top=tile.core_y == 0,
        at_bottom=tile.core_y + tile.core_h >= height,
        at_left=tile.core_x == 0,
        at_right=tile.core_x + tile.core_w >= width,
    )

    rows = slice(tile.core_y, tile.core_y + tile.core_h)
    cols = slice(tile.core_x, tile.core_x + tile.core_w)

    canvas[rows, cols] += core_image * blend[:, :, np.newaxis]
    weight_map[rows, cols] += blend
400
-
401
-
402
def finalize_blended_canvas(canvas: np.ndarray, weight_map: np.ndarray) -> np.ndarray:
    """Divide the weighted-sum canvas by its accumulated weights.

    Pixels never touched by any tile (zero weight) are left as the raw
    weighted sum, so no epsilon division can introduce black speckles; with
    boundary-aware masks such pixels should not occur.

    Args:
        canvas: Weighted pixel sums, shape (H, W, C).
        weight_map: Accumulated weights, shape (H, W).

    Returns:
        Normalized canvas, shape (H, W, C), float32.
    """
    normalized = canvas.copy()
    has_weight = weight_map > 0
    normalized[has_weight] = canvas[has_weight] / weight_map[has_weight, np.newaxis]
    return normalized
420
-
421
-
422
- # =============================================================================
423
- # Seam-fix band planning
424
- # =============================================================================
425
-
426
-
427
def plan_seam_fix_bands(
    tiles: list[TileSpec],
    image_width: int,
    image_height: int,
    seam_fix_width: int = 64,
    seam_fix_padding: int = 16,
) -> list[SeamFixSpec]:
    """Plan denoise bands centered on the seams between tile cores.

    Every interior bottom edge of a core becomes a horizontal band and every
    interior right edge a vertical band; the bands are denoised in a second
    pass to smooth out visible seams.

    Args:
        tiles: The tile plan (from plan_tiles_linear or plan_tiles_chess).
        image_width: Full image width.
        image_height: Full image height.
        seam_fix_width: Band width in pixels; 0 disables seam fixing.
        seam_fix_padding: Extra context padding around each band.

    Returns:
        All seam-fix bands, horizontal ones first.

    Raises:
        ValueError: If ``seam_fix_width`` or ``seam_fix_padding`` is negative.
    """
    if seam_fix_width < 0:
        raise ValueError(f"`seam_fix_width` must be non-negative, got {seam_fix_width}.")
    if seam_fix_width == 0:
        return []
    if seam_fix_padding < 0:
        raise ValueError(f"`seam_fix_padding` must be non-negative, got {seam_fix_padding}.")

    # Deduplicate boundary segments shared by several tiles.
    horizontal: set[tuple[int, int, int]] = set()  # (y, x_start, x_end)
    vertical: set[tuple[int, int, int]] = set()  # (x, y_start, y_end)

    for spec in tiles:
        # Interior bottom edge -> horizontal seam.
        bottom = spec.core_y + spec.core_h
        if bottom < image_height:
            horizontal.add((bottom, spec.core_x, spec.core_x + spec.core_w))
        # Interior right edge -> vertical seam.
        right = spec.core_x + spec.core_w
        if right < image_width:
            vertical.add((right, spec.core_y, spec.core_y + spec.core_h))

    before = seam_fix_width // 2
    after = seam_fix_width - before
    bands: list[SeamFixSpec] = []

    for y, x1, x2 in sorted(horizontal):
        top = max(0, y - before)
        bottom = min(image_height, y + after)

        crop_left = max(0, x1 - seam_fix_padding)
        crop_top = max(0, top - seam_fix_padding)
        crop_right = min(image_width, x2 + seam_fix_padding)
        crop_bottom = min(image_height, bottom + seam_fix_padding)

        bands.append(
            SeamFixSpec(
                band_x=x1,
                band_y=top,
                band_w=x2 - x1,
                band_h=bottom - top,
                crop_x=crop_left,
                crop_y=crop_top,
                crop_w=crop_right - crop_left,
                crop_h=crop_bottom - crop_top,
                paste_x=x1 - crop_left,
                paste_y=top - crop_top,
                orientation="horizontal",
            )
        )

    for x, y1, y2 in sorted(vertical):
        left = max(0, x - before)
        right = min(image_width, x + after)

        crop_left = max(0, left - seam_fix_padding)
        crop_top = max(0, y1 - seam_fix_padding)
        crop_right = min(image_width, right + seam_fix_padding)
        crop_bottom = min(image_height, y2 + seam_fix_padding)

        bands.append(
            SeamFixSpec(
                band_x=left,
                band_y=y1,
                band_w=right - left,
                band_h=y2 - y1,
                crop_x=crop_left,
                crop_y=crop_top,
                crop_w=crop_right - crop_left,
                crop_h=crop_bottom - crop_top,
                paste_x=left - crop_left,
                paste_y=y1 - crop_top,
                orientation="vertical",
            )
        )

    return bands
514
-
515
-
516
def extract_band_from_decoded(decoded_image: np.ndarray, band: SeamFixSpec) -> np.ndarray:
    """Slice the band region out of a decoded seam-fix crop.

    Args:
        decoded_image: Decoded crop pixels, shape (crop_h, crop_w, C).
        band: Band whose region should be extracted.

    Returns:
        Array view of shape (band_h, band_w, C).
    """
    top, left = band.paste_y, band.paste_x
    return decoded_image[top : top + band.band_h, left : left + band.band_w]
522
-
523
-
524
def make_seam_fix_mask(band: SeamFixSpec, mask_blur: int = 8) -> np.ndarray:
    """Build a linearly-feathered alpha mask for one seam-fix band.

    The mask is 1.0 along the seam center and fades to 0.0 toward the two
    edges perpendicular to the seam orientation, so the denoised band blends
    smoothly into the surrounding tile output.

    Args:
        band: Band geometry; its ``orientation`` selects the fade axis.
        mask_blur: Feather ramp width in pixels; <= 0 disables feathering.

    Returns:
        Mask of shape (band_h, band_w), float32, values in [0, 1].
    """
    mask = np.ones((band.band_h, band.band_w), dtype=np.float32)
    if mask_blur <= 0:
        return mask

    if band.orientation == "horizontal":
        # Horizontal seam: feather the top and bottom edges.
        feather = min(mask_blur, band.band_h // 2)
        if feather > 0:
            rise = np.linspace(0.0, 1.0, feather, dtype=np.float32)
            mask[:feather, :] = rise[:, np.newaxis]
            mask[-feather:, :] = rise[::-1][:, np.newaxis]
    else:
        # Vertical seam: feather the left and right edges.
        feather = min(mask_blur, band.band_w // 2)
        if feather > 0:
            rise = np.linspace(0.0, 1.0, feather, dtype=np.float32)
            mask[:, :feather] = rise[np.newaxis, :]
            mask[:, -feather:] = rise[::-1][np.newaxis, :]

    return mask
561
-
562
-
563
def paste_seam_fix_band(
    canvas: np.ndarray,
    band_image: np.ndarray,
    band: SeamFixSpec,
    mask_blur: int = 8,
) -> None:
    """Alpha-blend a denoised seam-fix band into the canvas.

    Args:
        canvas: Output canvas, shape (H, W, C), float32; modified in-place.
        band_image: Decoded band pixels, shape (band_h, band_w, C), float32.
        band: Band geometry giving the destination rectangle.
        mask_blur: Feather width forwarded to ``make_seam_fix_mask``.
    """
    alpha = make_seam_fix_mask(band, mask_blur)[:, :, np.newaxis]

    rows = slice(band.band_y, band.band_y + band.band_h)
    cols = slice(band.band_x, band.band_x + band.band_w)

    # Classic alpha composite: keep the existing pixels where alpha is 0,
    # take the denoised band where alpha is 1.
    canvas[rows, cols] = canvas[rows, cols] * (1 - alpha) + band_image * alpha
584
-
585
-
586
- # =============================================================================
587
- # Latent-space tile planning for MultiDiffusion
588
- # =============================================================================
589
 
590
 
591
  @dataclass
592
  class LatentTileSpec:
593
- """Tile specification in latent space for MultiDiffusion.
594
 
595
  Attributes:
596
  y: Top edge in latent pixels.
@@ -605,26 +32,7 @@ class LatentTileSpec:
605
  w: int
606
 
607
 
608
- def plan_latent_tiles(
609
- latent_h: int,
610
- latent_w: int,
611
- tile_size: int = 64,
612
- overlap: int = 8,
613
- ) -> list[LatentTileSpec]:
614
- """Plan overlapping tiles in latent space for MultiDiffusion.
615
-
616
- Tiles overlap by ``overlap`` latent pixels. The stride is
617
- ``tile_size - overlap``. Edge tiles are clamped to the latent bounds.
618
-
619
- Args:
620
- latent_h: Height of the full latent tensor.
621
- latent_w: Width of the full latent tensor.
622
- tile_size: Tile size in latent pixels (e.g., 64 = 512px at scale 8).
623
- overlap: Overlap in latent pixels (e.g., 8 = 64px at scale 8).
624
-
625
- Returns:
626
- List of ``LatentTileSpec``.
627
- """
628
  if tile_size <= 0:
629
  raise ValueError(f"`tile_size` must be positive, got {tile_size}.")
630
  if overlap < 0:
@@ -635,13 +43,26 @@ def plan_latent_tiles(
635
  f"Got overlap={overlap}, tile_size={tile_size}."
636
  )
637
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
638
  stride = tile_size - overlap
639
  tiles: list[LatentTileSpec] = []
640
 
641
  y = 0
642
  while y < latent_h:
643
  h = min(tile_size, latent_h - y)
644
- # If remaining height is less than tile_size, shift back to get a full tile
645
  if h < tile_size and y > 0:
646
  y = max(0, latent_h - tile_size)
647
  h = latent_h - y
@@ -666,6 +87,50 @@ def plan_latent_tiles(
666
  return tiles
667
 
668
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
669
  # ============================================================
670
  # input
671
  # ============================================================
@@ -684,6 +149,8 @@ def plan_latent_tiles(
684
  # See the License for the specific language governing permissions and
685
  # limitations under the License.
686
 
 
 
687
  import PIL.Image
688
  import torch
689
 
@@ -699,43 +166,29 @@ logger = logging.get_logger(__name__)
699
  class UltimateSDUpscaleTextEncoderStep(StableDiffusionXLTextEncoderStep):
700
  """SDXL text encoder step that applies guidance scale before encoding.
701
 
702
- StableDiffusionXLTextEncoderStep decides whether to prepare unconditional
703
- embeddings based on `components.guider.num_conditions`. This depends on the
704
- current `components.guider.guidance_scale` value.
705
 
706
- In tiled upscaling, users may call the same pipeline repeatedly with
707
- different `guidance_scale` values. Without syncing the guider scale before
708
- text encoding, a previous run can leave the guider in a stale state and
709
- cause missing negative embeddings on the next run.
710
-
711
- Also applies a sensible default negative prompt for upscaling when the user
712
- does not provide one, controlled by ``use_default_negative``.
713
  """
714
 
715
  DEFAULT_NEGATIVE_PROMPT = "blurry, low quality, artifacts, noise, jpeg compression"
716
 
717
  @property
718
  def inputs(self) -> list[InputParam]:
719
- # Keep all SDXL text-encoder inputs and add guidance_scale override.
720
  return super().inputs + [
721
  InputParam(
722
  "guidance_scale",
723
  type_hint=float,
724
  default=7.5,
725
- description=(
726
- "Classifier-Free Guidance scale used to configure the guider "
727
- "before prompt encoding."
728
- ),
729
  ),
730
  InputParam(
731
  "use_default_negative",
732
  type_hint=bool,
733
  default=True,
734
- description=(
735
- "When True and negative_prompt is None or empty, apply a default "
736
- "negative prompt optimized for upscaling: "
737
- "'blurry, low quality, artifacts, noise, jpeg compression'."
738
- ),
739
  ),
740
  ]
741
 
@@ -747,7 +200,6 @@ class UltimateSDUpscaleTextEncoderStep(StableDiffusionXLTextEncoderStep):
747
  if hasattr(components, "guider") and components.guider is not None:
748
  components.guider.guidance_scale = guidance_scale
749
 
750
- # Apply default negative prompt if user didn't provide one
751
  use_default_negative = getattr(block_state, "use_default_negative", True)
752
  if use_default_negative:
753
  neg = getattr(block_state, "negative_prompt", None)
@@ -759,56 +211,27 @@ class UltimateSDUpscaleTextEncoderStep(StableDiffusionXLTextEncoderStep):
759
 
760
 
761
  class UltimateSDUpscaleUpscaleStep(ModularPipelineBlocks):
762
- """Upscales the input image using Lanczos interpolation.
763
-
764
- This is the first custom step in the tiled upscaling workflow.
765
- It takes an input image and upscale factor, producing an upscaled image
766
- that subsequent tile steps will refine.
767
- """
768
 
769
  @property
770
  def description(self) -> str:
771
- return (
772
- "Upscale step that resizes the input image by a given factor.\n"
773
- "Currently supports Lanczos interpolation. Model-based upscalers "
774
- "can be added in future passes."
775
- )
776
 
777
  @property
778
  def inputs(self) -> list[InputParam]:
779
  return [
780
- InputParam(
781
- "image",
782
- type_hint=PIL.Image.Image,
783
- required=True,
784
- description="The input image to upscale and refine.",
785
- ),
786
- InputParam(
787
- "upscale_factor",
788
- type_hint=float,
789
- default=2.0,
790
- description="Factor by which to upscale the input image.",
791
- ),
792
  ]
793
 
794
  @property
795
  def intermediate_outputs(self) -> list[OutputParam]:
796
  return [
797
- OutputParam(
798
- "upscaled_image",
799
- type_hint=PIL.Image.Image,
800
- description="The upscaled image before tile-based refinement.",
801
- ),
802
- OutputParam(
803
- "upscaled_width",
804
- type_hint=int,
805
- description="Width of the upscaled image.",
806
- ),
807
- OutputParam(
808
- "upscaled_height",
809
- type_hint=int,
810
- description="Height of the upscaled image.",
811
- ),
812
  ]
813
 
814
  @torch.no_grad()
@@ -819,10 +242,7 @@ class UltimateSDUpscaleUpscaleStep(ModularPipelineBlocks):
819
  upscale_factor = block_state.upscale_factor
820
 
821
  if not isinstance(image, PIL.Image.Image):
822
- raise ValueError(
823
- f"Expected `image` to be a PIL.Image.Image, got {type(image)}. "
824
- "Please pass a PIL image to the pipeline."
825
- )
826
 
827
  new_width = int(image.width * upscale_factor)
828
  new_height = int(image.height * upscale_factor)
@@ -831,1107 +251,209 @@ class UltimateSDUpscaleUpscaleStep(ModularPipelineBlocks):
831
  block_state.upscaled_width = new_width
832
  block_state.upscaled_height = new_height
833
 
834
- logger.info(
835
- f"Upscaled image from {image.width}x{image.height} to {new_width}x{new_height} "
836
- f"(factor={upscale_factor})"
837
- )
838
-
839
- self.set_block_state(state, block_state)
840
- return components, state
841
-
842
-
843
class UltimateSDUpscaleTilePlanStep(ModularPipelineBlocks):
    """Compute the tile layout for the upscaled canvas.

    Emits a list of ``TileSpec`` objects covering the upscaled image, honoring
    the requested tile size/padding and either raster ('linear') or
    checkerboard ('chess') ordering. When ``seam_fix_width`` is positive,
    seam-fix bands along tile boundaries are planned as well.
    """

    @property
    def description(self) -> str:
        return (
            "Tile planning step that generates tile coordinates for the upscaled image.\n"
            "Supports 'linear' (raster) and 'chess' (checkerboard) traversal.\n"
            "Optionally plans seam-fix bands along tile boundaries."
        )

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("upscaled_width", type_hint=int, required=True,
                       description="Width of the upscaled image."),
            InputParam("upscaled_height", type_hint=int, required=True,
                       description="Height of the upscaled image."),
            InputParam("tile_size", type_hint=int, default=2048,
                       description="Base tile size in pixels. Default 2048 processes most images "
                       "in a single pass (seamless). Set to 512 for tiled mode on very large images."),
            InputParam("tile_padding", type_hint=int, default=32,
                       description="Number of overlap pixels on each side of a tile. Only relevant when tiling."),
            InputParam("traversal_mode", type_hint=str, default="linear",
                       description="Tile traversal order: 'linear' or 'chess'."),
            InputParam("seam_fix_width", type_hint=int, default=0,
                       description="Width of seam-fix bands in pixels. 0 disables seam fixing."),
            InputParam("seam_fix_padding", type_hint=int, default=16,
                       description="Extra padding around seam-fix bands for denoise context."),
            InputParam("seam_fix_mask_blur", type_hint=int, default=8,
                       description="Feathering width for seam-fix blending masks."),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("tile_plan", type_hint=list,
                        description="List of TileSpec defining the tile grid."),
            OutputParam("num_tiles", type_hint=int,
                        description="Total number of tiles in the plan."),
            OutputParam("seam_fix_plan", type_hint=list,
                        description="List of SeamFixSpec for seam-fix bands (empty if disabled)."),
            OutputParam("seam_fix_mask_blur", type_hint=int,
                        description="Feathering width for seam-fix blending."),
        ]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        tile_size = block_state.tile_size
        tile_padding = block_state.tile_padding
        traversal_mode = block_state.traversal_mode

        # Reject unknown traversal orders before doing any planning work.
        if traversal_mode not in ("linear", "chess"):
            raise ValueError(
                f"Unsupported traversal_mode '{traversal_mode}'. "
                "Supported modes: 'linear', 'chess'."
            )

        validate_tile_params(tile_size, tile_padding)

        # Both planners share the same keyword signature; pick one by mode.
        planner = plan_tiles_chess if traversal_mode == "chess" else plan_tiles_linear
        tile_plan = planner(
            image_width=block_state.upscaled_width,
            image_height=block_state.upscaled_height,
            tile_size=tile_size,
            tile_padding=tile_padding,
        )

        seam_fix_width = block_state.seam_fix_width
        seam_fix_padding = block_state.seam_fix_padding
        seam_fix_mask_blur = block_state.seam_fix_mask_blur

        # Seam-fix parameters must be non-negative; validate them up front.
        if seam_fix_width < 0:
            raise ValueError(f"`seam_fix_width` must be non-negative, got {seam_fix_width}.")
        if seam_fix_padding < 0:
            raise ValueError(f"`seam_fix_padding` must be non-negative, got {seam_fix_padding}.")
        if seam_fix_mask_blur < 0:
            raise ValueError(f"`seam_fix_mask_blur` must be non-negative, got {seam_fix_mask_blur}.")

        # A zero band width disables seam fixing entirely.
        seam_fix_plan = []
        if seam_fix_width > 0:
            seam_fix_plan = plan_seam_fix_bands(
                tiles=tile_plan,
                image_width=block_state.upscaled_width,
                image_height=block_state.upscaled_height,
                seam_fix_width=seam_fix_width,
                seam_fix_padding=seam_fix_padding,
            )

        block_state.tile_plan = tile_plan
        block_state.num_tiles = len(tile_plan)
        block_state.seam_fix_plan = seam_fix_plan
        block_state.seam_fix_mask_blur = seam_fix_mask_blur

        logger.info(
            f"Planned {len(tile_plan)} tiles "
            f"(tile_size={tile_size}, padding={tile_padding}, traversal={traversal_mode})"
            + (f", {len(seam_fix_plan)} seam-fix bands" if seam_fix_plan else "")
        )

        self.set_block_state(state, block_state)
        return components, state
961
 
962
 
963
- # ============================================================
964
- # denoise
965
- # ============================================================
966
-
967
- # Copyright 2025 The HuggingFace Team. All rights reserved.
968
- #
969
- # Licensed under the Apache License, Version 2.0 (the "License");
970
- # you may not use this file except in compliance with the License.
971
- # You may obtain a copy of the License at
972
- #
973
- # http://www.apache.org/licenses/LICENSE-2.0
974
- #
975
- # Unless required by applicable law or agreed to in writing, software
976
- # distributed under the License is distributed on an "AS IS" BASIS,
977
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
978
- # See the License for the specific language governing permissions and
979
- # limitations under the License.
980
-
981
- """Tiled upscaling denoise steps for Modular SDXL Upscale.
982
-
983
- Architecture follows the ``LoopSequentialPipelineBlocks`` pattern used by the
984
- SDXL denoising loop. ``UltimateSDUpscaleTileLoopStep`` is the loop wrapper
985
- (iterates over *tiles*); its sub-blocks are leaf blocks that handle one tile
986
- per call:
987
-
988
- TilePrepareStep – crop, VAE encode, prepare latents, tile-aware add_cond
989
- TileDenoiserStep – full denoising loop (wraps ``StableDiffusionXLDenoiseStep``)
990
- TilePostProcessStep – decode latents, extract core, paste into canvas
991
-
992
- SDXL blocks are reused via their public interface by creating temporary
993
- ``PipelineState`` objects, NOT by calling private helpers.
994
- """
995
-
996
- import math
997
- import time
998
-
999
- import numpy as np
1000
- import PIL.Image
1001
- import torch
1002
- from tqdm.auto import tqdm
1003
-
1004
- from diffusers.configuration_utils import FrozenDict
1005
- from diffusers.guiders import ClassifierFreeGuidance
1006
- from diffusers.image_processor import VaeImageProcessor
1007
- from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
1008
- from diffusers.schedulers import DPMSolverMultistepScheduler, EulerDiscreteScheduler
1009
- from diffusers.utils import logging
1010
- from diffusers.utils.torch_utils import randn_tensor
1011
- from diffusers.modular_pipelines.modular_pipeline import (
1012
- BlockState,
1013
- LoopSequentialPipelineBlocks,
1014
- ModularPipelineBlocks,
1015
- PipelineState,
1016
- )
1017
- from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
1018
- from diffusers.modular_pipelines.stable_diffusion_xl.before_denoise import (
1019
- StableDiffusionXLControlNetInputStep,
1020
- StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
1021
- StableDiffusionXLImg2ImgPrepareLatentsStep,
1022
- prepare_latents_img2img,
1023
- )
1024
- from diffusers.modular_pipelines.stable_diffusion_xl.decoders import StableDiffusionXLDecodeStep
1025
- from diffusers.modular_pipelines.stable_diffusion_xl.denoise import StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep
1026
- from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLVaeEncoderStep
1027
-
1028
-
1029
- logger = logging.get_logger(__name__)
1030
-
1031
-
1032
- # ---------------------------------------------------------------------------
1033
- # Helper: populate a PipelineState from a dict
1034
- # ---------------------------------------------------------------------------
1035
-
1036
def _make_state(values: dict, kwargs_type_map: dict | None = None) -> PipelineState:
    """Build a fresh ``PipelineState`` pre-populated from ``values``.

    ``kwargs_type_map`` optionally maps a value name to the ``kwargs_type``
    forwarded to ``PipelineState.set``; names absent from the map get ``None``.
    """
    type_lookup = {} if kwargs_type_map is None else kwargs_type_map
    new_state = PipelineState()
    for name, value in values.items():
        new_state.set(name, value, type_lookup.get(name))
    return new_state
1043
-
1044
-
1045
def _to_pil_rgb_image(image) -> PIL.Image.Image:
    """Coerce ``image`` (PIL image, torch tensor, or numpy array) to an RGB PIL image.

    Tensors/arrays may carry a leading batch dimension of 1 and may be
    channels-first; grayscale and RGBA inputs are converted to 3-channel RGB.
    Float inputs in [0, 1] are rescaled to [0, 255] before conversion.
    """
    if isinstance(image, PIL.Image.Image):
        return image.convert("RGB")

    if torch.is_tensor(image):
        tensor = image.detach().cpu()
        if tensor.ndim == 4:
            # Only a singleton batch can be collapsed to a single image.
            if tensor.shape[0] != 1:
                raise ValueError(
                    f"`control_image` tensor batch must be 1 for tiled upscaling, got shape {tuple(tensor.shape)}."
                )
            tensor = tensor[0]
        # CHW → HWC, but only when the layout is unambiguous.
        if tensor.ndim == 3 and tensor.shape[0] in (1, 3, 4) and tensor.shape[-1] not in (1, 3, 4):
            tensor = tensor.permute(1, 2, 0)
        image = tensor.numpy()

    if not isinstance(image, np.ndarray):
        raise ValueError(
            f"Unsupported `control_image` type {type(image)}. Expected PIL.Image, torch.Tensor, or numpy.ndarray."
        )

    array = image
    if array.ndim == 4:
        if array.shape[0] != 1:
            raise ValueError(
                f"`control_image` ndarray batch must be 1 for tiled upscaling, got shape {array.shape}."
            )
        array = array[0]
    # CHW → HWC when the channel axis is clearly first.
    if array.ndim == 3 and array.shape[0] in (1, 3, 4) and array.shape[-1] not in (1, 3, 4):
        array = np.transpose(array, (1, 2, 0))
    # Grayscale without a channel axis → replicate to three channels.
    if array.ndim == 2:
        array = np.stack([array] * 3, axis=-1)
    if array.ndim != 3:
        raise ValueError(f"`control_image` must have 2 or 3 dimensions, got shape {array.shape}.")
    if array.shape[-1] == 1:
        array = np.repeat(array, 3, axis=-1)
    if array.shape[-1] == 4:
        # Drop the alpha channel.
        array = array[..., :3]
    if array.shape[-1] != 3:
        raise ValueError(f"`control_image` channel dimension must be 1/3/4, got shape {array.shape}.")
    if array.dtype != np.uint8:
        array = np.asarray(array, dtype=np.float32)
        # Heuristic: values entirely within [0, 1] are treated as normalized.
        peak = float(np.max(array)) if array.size > 0 else 1.0
        if peak <= 1.0:
            array = (np.clip(array, 0.0, 1.0) * 255.0).astype(np.uint8)
        else:
            array = np.clip(array, 0.0, 255.0).astype(np.uint8)
    return PIL.Image.fromarray(array).convert("RGB")
1094
-
1095
-
1096
- # ---------------------------------------------------------------------------
1097
- # Scheduler swap helper (Feature 5)
1098
- # ---------------------------------------------------------------------------
1099
-
1100
_SCHEDULER_ALIASES = {
    "euler": "EulerDiscreteScheduler",
    "euler discrete": "EulerDiscreteScheduler",
    "eulerdiscretescheduler": "EulerDiscreteScheduler",
    "dpm++ 2m": "DPMSolverMultistepScheduler",
    "dpmsolvermultistepscheduler": "DPMSolverMultistepScheduler",
    "dpm++ 2m karras": "DPMSolverMultistepScheduler+karras",
}


def _swap_scheduler(components, scheduler_name: str):
    """Replace ``components.scheduler`` based on a human-readable name.

    Recognized names (case-insensitive):
    - ``"Euler"`` / ``"EulerDiscreteScheduler"``
    - ``"DPM++ 2M"`` / ``"DPMSolverMultistepScheduler"``
    - ``"DPM++ 2M Karras"`` (DPMSolverMultistep with Karras sigmas)

    A no-op when the requested scheduler is already active; unknown names
    log a warning and keep the current scheduler.
    """
    normalized = scheduler_name.strip().lower()
    target = _SCHEDULER_ALIASES.get(normalized, normalized)

    # The "+karras" suffix is an internal marker, not a class name.
    use_karras = target.endswith("+karras")
    if use_karras:
        target = target.replace("+karras", "")

    current = type(components.scheduler).__name__

    if target == "EulerDiscreteScheduler":
        if current != "EulerDiscreteScheduler":
            components.scheduler = EulerDiscreteScheduler.from_config(components.scheduler.config)
            logger.info("Swapped scheduler to EulerDiscreteScheduler")
    elif target == "DPMSolverMultistepScheduler":
        # Also re-create when Karras sigmas are requested but not yet enabled.
        karras_missing = use_karras and not getattr(components.scheduler.config, "use_karras_sigmas", False)
        if current != "DPMSolverMultistepScheduler" or karras_missing:
            extra_kwargs = {}
            if use_karras:
                extra_kwargs["use_karras_sigmas"] = True
            components.scheduler = DPMSolverMultistepScheduler.from_config(
                components.scheduler.config, **extra_kwargs
            )
            logger.info(f"Swapped scheduler to DPMSolverMultistepScheduler (karras={use_karras})")
    else:
        logger.warning(
            f"Unknown scheduler_name '{scheduler_name}'. Keeping current scheduler "
            f"({current}). Supported: 'Euler', 'DPM++ 2M', 'DPM++ 2M Karras'."
        )
1149
-
1150
-
1151
- # ---------------------------------------------------------------------------
1152
- # Auto-strength helper (Feature 2)
1153
- # ---------------------------------------------------------------------------
1154
-
1155
- def _compute_auto_strength(upscale_factor: float, pass_index: int, num_passes: int) -> float:
1156
- """Return the auto-scaled denoise strength for a given pass.
1157
-
1158
- Rules:
1159
- - Single-pass 2x: 0.3
1160
- - Single-pass 4x: 0.15
1161
- - Progressive passes: first pass=0.3, subsequent passes=0.2
1162
- """
1163
- if num_passes > 1:
1164
- return 0.3 if pass_index == 0 else 0.2
1165
- # Single pass
1166
- if upscale_factor <= 2.0:
1167
- return 0.3
1168
- elif upscale_factor <= 4.0:
1169
- return 0.15
1170
- else:
1171
- return 0.1
1172
-
1173
-
1174
- # ---------------------------------------------------------------------------
1175
- # Loop sub-block 1: Prepare (crop + encode + timesteps + latents + add_cond)
1176
- # ---------------------------------------------------------------------------
1177
-
1178
class UltimateSDUpscaleTilePrepareStep(ModularPipelineBlocks):
    """Loop sub-block that readies one tile for denoising.

    Per tile: crops the padded region, VAE-encodes it, resets the scheduler's
    mutable step index (the outer set_timesteps result is reused — never
    recomputed, to avoid double-applying strength), prepares the noised
    latents (preferring a global noise map for cross-tile consistency), and
    builds tile-aware SDXL additional conditioning. Optionally prepares
    ControlNet conditioning for the tile.

    All SDXL blocks are invoked through their public ``__call__`` interface
    on temporary ``PipelineState`` objects.
    """

    model_name = "stable-diffusion-xl"

    def __init__(self):
        super().__init__()
        # Held as plain attributes (not sub_blocks) so this block stays a leaf.
        self._vae_encoder = StableDiffusionXLVaeEncoderStep()
        self._prepare_latents = StableDiffusionXLImg2ImgPrepareLatentsStep()
        self._prepare_add_cond = StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep()
        self._prepare_controlnet = StableDiffusionXLControlNetInputStep()

    @property
    def description(self) -> str:
        return (
            "Loop sub-block: crops a tile, encodes to latents, resets scheduler "
            "timesteps, prepares latents, and computes tile-aware additional conditioning."
        )

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 8}),
                default_creation_method="from_config",
            ),
            ComponentSpec("scheduler", EulerDiscreteScheduler),
            ComponentSpec("unet", UNet2DConditionModel),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 7.5}),
                default_creation_method="from_config",
            ),
            ComponentSpec(
                "control_image_processor",
                VaeImageProcessor,
                config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def expected_configs(self) -> list[ConfigSpec]:
        return [ConfigSpec("requires_aesthetics_score", False)]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("upscaled_image", type_hint=PIL.Image.Image, required=True),
            InputParam("upscaled_height", type_hint=int, required=True),
            InputParam("upscaled_width", type_hint=int, required=True),
            InputParam("generator"),
            InputParam("batch_size", type_hint=int, required=True),
            InputParam("num_images_per_prompt", type_hint=int, default=1),
            InputParam("dtype", type_hint=torch.dtype, required=True),
            InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, kwargs_type="denoiser_input_fields"),
            InputParam("num_inference_steps", type_hint=int, default=50),
            InputParam("strength", type_hint=float, default=0.3),
            InputParam("timesteps", type_hint=torch.Tensor, required=True),
            InputParam("latent_timestep", type_hint=torch.Tensor, required=True),
            InputParam("denoising_start"),
            InputParam("denoising_end"),
            InputParam("use_controlnet", type_hint=bool, default=False),
            InputParam("control_image_processed"),
            InputParam("control_guidance_start", default=0.0),
            InputParam("control_guidance_end", default=1.0),
            InputParam("controlnet_conditioning_scale", default=1.0),
            InputParam("guess_mode", default=False),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("latents", type_hint=torch.Tensor),
            OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
            OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
            OutputParam("timestep_cond", type_hint=torch.Tensor),
            OutputParam("controlnet_cond", type_hint=torch.Tensor),
            OutputParam("conditioning_scale"),
            OutputParam("controlnet_keep", type_hint=list[float]),
            OutputParam("guess_mode", type_hint=bool),
        ]

    @torch.no_grad()
    def __call__(self, components, block_state: BlockState, tile_idx: int, tile: TileSpec):
        # Step 1: crop the padded tile region out of the upscaled canvas.
        tile_image = crop_tile(block_state.upscaled_image, tile)

        # Step 2: VAE-encode the crop through the public SDXL encoder block.
        enc_state = _make_state({
            "image": tile_image,
            "height": tile.crop_h,
            "width": tile.crop_w,
            "generator": block_state.generator,
            "dtype": block_state.dtype,
            "preprocess_kwargs": None,
        })
        components, enc_state = self._vae_encoder(components, enc_state)
        image_latents = enc_state.get("image_latents")

        # Step 3: reset only the scheduler's mutable progress counter.
        # The outer set_timesteps block already produced the correct schedule
        # (strength applied); re-running it here would double-apply strength.
        # _begin_index is deliberately left untouched — it marks where in the
        # schedule this strength level starts, and zeroing it would pair
        # full-noise sigmas with partially-noised latents.
        scheduler = components.scheduler
        latent_timestep = block_state.latent_timestep
        if hasattr(scheduler, "_step_index"):
            scheduler._step_index = None
        if hasattr(scheduler, "is_scale_input_called"):
            scheduler.is_scale_input_called = False

        # Step 4: build clean init latents (no noise), then noise them with a
        # crop from the global noise map so noise stays spatially consistent
        # across tiles, minimizing cross-tile drift.
        base_latents = prepare_latents_img2img(
            components.vae,
            components.scheduler,
            image_latents,
            latent_timestep,
            block_state.batch_size,
            block_state.num_images_per_prompt,
            block_state.dtype,
            image_latents.device,
            generator=None,
            add_noise=False,
        )

        latent_h, latent_w = base_latents.shape[-2], base_latents.shape[-1]
        shared_noise = getattr(block_state, "global_noise_map", None)
        if shared_noise is not None:
            # Map the pixel-space crop origin into latent space, clamped so the
            # slice never runs off the edge of the noise map.
            latent_scale = int(getattr(block_state, "global_noise_scale", 8))
            row0 = min(max(0, tile.crop_y // latent_scale), max(0, shared_noise.shape[-2] - latent_h))
            col0 = min(max(0, tile.crop_x // latent_scale), max(0, shared_noise.shape[-1] - latent_w))
            tile_noise = shared_noise[:, :, row0 : row0 + latent_h, col0 : col0 + latent_w]

            # Defensive fallback if latent shape and crop math ever diverge.
            if tile_noise.shape != base_latents.shape:
                tile_noise = randn_tensor(
                    base_latents.shape,
                    generator=block_state.generator,
                    device=base_latents.device,
                    dtype=base_latents.dtype,
                )
        else:
            tile_noise = randn_tensor(
                base_latents.shape,
                generator=block_state.generator,
                device=base_latents.device,
                dtype=base_latents.dtype,
            )

        pre_noised_latents = components.scheduler.add_noise(base_latents, tile_noise, latent_timestep)

        lat_state = _make_state({
            "image_latents": image_latents,
            "latent_timestep": latent_timestep,
            "batch_size": block_state.batch_size,
            "num_images_per_prompt": block_state.num_images_per_prompt,
            "dtype": block_state.dtype,
            "generator": block_state.generator,
            "latents": pre_noised_latents,
            "denoising_start": getattr(block_state, "denoising_start", None),
        })
        components, lat_state = self._prepare_latents(components, lat_state)

        # Step 5: tile-aware SDXL micro-conditioning.
        # crops_coords_top_left places this tile within the canvas,
        # target_size is the tile's own pixel extent, and original_size is
        # the full upscaled image.
        cond_state = _make_state({
            "original_size": (block_state.upscaled_height, block_state.upscaled_width),
            "target_size": (tile.crop_h, tile.crop_w),
            "crops_coords_top_left": (tile.crop_y, tile.crop_x),
            "negative_original_size": None,
            "negative_target_size": None,
            "negative_crops_coords_top_left": (0, 0),
            "num_images_per_prompt": block_state.num_images_per_prompt,
            "aesthetic_score": 6.0,
            "negative_aesthetic_score": 2.0,
            "latents": lat_state.get("latents"),
            "pooled_prompt_embeds": block_state.pooled_prompt_embeds,
            "batch_size": block_state.batch_size,
        })
        components, cond_state = self._prepare_add_cond(components, cond_state)

        # Publish results. timesteps / num_inference_steps / latent_timestep
        # already live on block_state from the outer set_timesteps step.
        block_state.latents = lat_state.get("latents")
        block_state.add_time_ids = cond_state.get("add_time_ids")
        block_state.negative_add_time_ids = cond_state.get("negative_add_time_ids")
        block_state.timestep_cond = cond_state.get("timestep_cond")

        if getattr(block_state, "use_controlnet", False):
            # Crop the matching ControlNet image region and run the public
            # SDXL ControlNet input block on it.
            control_tile = crop_tile(block_state.control_image_processed, tile)
            control_state = _make_state({
                "control_image": control_tile,
                "control_guidance_start": getattr(block_state, "control_guidance_start", 0.0),
                "control_guidance_end": getattr(block_state, "control_guidance_end", 1.0),
                "controlnet_conditioning_scale": getattr(block_state, "controlnet_conditioning_scale", 1.0),
                "guess_mode": getattr(block_state, "guess_mode", False),
                "num_images_per_prompt": block_state.num_images_per_prompt,
                "latents": block_state.latents,
                "batch_size": block_state.batch_size,
                "timesteps": block_state.timesteps,
                "crops_coords": None,
            })
            components, control_state = self._prepare_controlnet(components, control_state)
            block_state.controlnet_cond = control_state.get("controlnet_cond")
            block_state.conditioning_scale = control_state.get("conditioning_scale")
            block_state.controlnet_keep = control_state.get("controlnet_keep")
            block_state.guess_mode = control_state.get("guess_mode")
        else:
            block_state.controlnet_cond = None
            block_state.conditioning_scale = None
            block_state.controlnet_keep = None

        return components, block_state
1426
-
1427
-
1428
- # ---------------------------------------------------------------------------
1429
- # Loop sub-block 2: Denoise
1430
- # ---------------------------------------------------------------------------
1431
-
1432
class UltimateSDUpscaleTileDenoiserStep(ModularPipelineBlocks):
    """Loop sub-block that denoises one tile end to end.

    Delegates to ``StableDiffusionXLDenoiseStep`` (or, when ControlNet is
    active, ``StableDiffusionXLControlNetDenoiseStep``) — each itself a
    ``LoopSequentialPipelineBlocks`` over timesteps. Both are held as plain
    attributes rather than sub_blocks so this block remains a leaf.
    """

    model_name = "stable-diffusion-xl"

    def __init__(self):
        super().__init__()
        self._denoise = StableDiffusionXLDenoiseStep()
        self._controlnet_denoise = StableDiffusionXLControlNetDenoiseStep()

    @property
    def description(self) -> str:
        return (
            "Loop sub-block: runs the SDXL denoising loop for one tile, "
            "with optional ControlNet conditioning."
        )

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("unet", UNet2DConditionModel),
            ComponentSpec("scheduler", EulerDiscreteScheduler),
            ComponentSpec("controlnet", ControlNetModel),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 7.5}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("latents", type_hint=torch.Tensor, required=True),
            InputParam("timesteps", type_hint=torch.Tensor, required=True),
            InputParam("num_inference_steps", type_hint=int, required=True),
            # Denoiser input fields (kwargs_type must match text encoder outputs)
            InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, kwargs_type="denoiser_input_fields"),
            InputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
            InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, kwargs_type="denoiser_input_fields"),
            InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
            InputParam("add_time_ids", type_hint=torch.Tensor, required=True, kwargs_type="denoiser_input_fields"),
            InputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
            InputParam("timestep_cond", type_hint=torch.Tensor),
            InputParam("eta", type_hint=float, default=0.0),
            InputParam("generator"),
            InputParam("use_controlnet", type_hint=bool, default=False),
            InputParam("controlnet_cond", type_hint=torch.Tensor),
            InputParam("conditioning_scale"),
            InputParam("controlnet_keep", type_hint=list[float]),
            InputParam("guess_mode", type_hint=bool, default=False),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents."),
        ]

    @torch.no_grad()
    def __call__(self, components, block_state: BlockState, tile_idx: int, tile: TileSpec):
        # Collect the conditioning tensors that must carry the
        # "denoiser_input_fields" kwargs_type through the temporary state.
        denoiser_fields = {
            "prompt_embeds": block_state.prompt_embeds,
            "negative_prompt_embeds": getattr(block_state, "negative_prompt_embeds", None),
            "pooled_prompt_embeds": block_state.pooled_prompt_embeds,
            "negative_pooled_prompt_embeds": getattr(block_state, "negative_pooled_prompt_embeds", None),
            "add_time_ids": block_state.add_time_ids,
            "negative_add_time_ids": getattr(block_state, "negative_add_time_ids", None),
        }
        # IP-Adapter embeddings are optional; include them only when present.
        for field in ("ip_adapter_embeds", "negative_ip_adapter_embeds"):
            value = getattr(block_state, field, None)
            if value is not None:
                denoiser_fields[field] = value

        kwargs_type_map = dict.fromkeys(denoiser_fields, "denoiser_input_fields")

        payload = dict(denoiser_fields)
        payload.update({
            "latents": block_state.latents,
            "timesteps": block_state.timesteps,
            "num_inference_steps": block_state.num_inference_steps,
            "timestep_cond": getattr(block_state, "timestep_cond", None),
            "eta": getattr(block_state, "eta", 0.0),
            "generator": getattr(block_state, "generator", None),
        })

        use_controlnet = bool(getattr(block_state, "use_controlnet", False))
        if use_controlnet:
            payload.update({
                "controlnet_cond": block_state.controlnet_cond,
                "conditioning_scale": block_state.conditioning_scale,
                "guess_mode": getattr(block_state, "guess_mode", False),
                "controlnet_keep": block_state.controlnet_keep,
                "controlnet_kwargs": getattr(block_state, "controlnet_kwargs", {}),
            })

        denoise_state = _make_state(payload, kwargs_type_map)
        runner = self._controlnet_denoise if use_controlnet else self._denoise
        components, denoise_state = runner(components, denoise_state)

        block_state.latents = denoise_state.get("latents")
        return components, block_state
1547
-
1548
-
1549
- # ---------------------------------------------------------------------------
1550
- # Loop sub-block 3: Decode + paste into canvas
1551
- # ---------------------------------------------------------------------------
1552
-
1553
- class UltimateSDUpscaleTilePostProcessStep(ModularPipelineBlocks):
1554
- """Loop sub-block that decodes one tile and pastes the core into the canvas.
1555
-
1556
- Supports two blending modes:
1557
- - ``"none"``: Non-overlapping core paste (fastest, default).
1558
- - ``"gradient"``: Gradient overlap blending for smoother tile transitions.
1559
- """
1560
-
1561
- model_name = "stable-diffusion-xl"
1562
-
1563
- def __init__(self):
1564
- super().__init__()
1565
- self._decode = StableDiffusionXLDecodeStep()
1566
-
1567
- @property
1568
- def description(self) -> str:
1569
- return (
1570
- "Loop sub-block: decodes latents to an image via StableDiffusionXLDecodeStep, "
1571
- "then extracts the core region and pastes it into the output canvas. "
1572
- "Supports 'none' and 'gradient' blending modes."
1573
- )
1574
-
1575
- @property
1576
- def expected_components(self) -> list[ComponentSpec]:
1577
- return [
1578
- ComponentSpec("vae", AutoencoderKL),
1579
- ComponentSpec(
1580
- "image_processor",
1581
- VaeImageProcessor,
1582
- config=FrozenDict({"vae_scale_factor": 8}),
1583
- default_creation_method="from_config",
1584
- ),
1585
- ]
1586
-
1587
- @property
1588
- def inputs(self) -> list[InputParam]:
1589
- return [
1590
- InputParam("latents", type_hint=torch.Tensor, required=True),
1591
- ]
1592
-
1593
- @property
1594
- def intermediate_outputs(self) -> list[OutputParam]:
1595
- return [] # Canvas is modified in-place on block_state
1596
-
1597
- @torch.no_grad()
1598
- def __call__(self, components, block_state: BlockState, tile_idx: int, tile: TileSpec):
1599
- decode_state = _make_state({
1600
- "latents": block_state.latents,
1601
- "output_type": "np",
1602
- })
1603
- components, decode_state = self._decode(components, decode_state)
1604
- decoded_images = decode_state.get("images")
1605
-
1606
- decoded_np = decoded_images[0] # shape: (crop_h, crop_w, 3)
1607
-
1608
- if decoded_np.shape[0] != tile.crop_h or decoded_np.shape[1] != tile.crop_w:
1609
- pil_tile = PIL.Image.fromarray((np.clip(decoded_np, 0, 1) * 255).astype(np.uint8))
1610
- pil_tile = pil_tile.resize((tile.crop_w, tile.crop_h), PIL.Image.LANCZOS)
1611
- decoded_np = np.array(pil_tile).astype(np.float32) / 255.0
1612
-
1613
- core = extract_core_from_decoded(decoded_np, tile)
1614
-
1615
- blend_mode = getattr(block_state, "blend_mode", "none")
1616
- if blend_mode == "gradient":
1617
- overlap = getattr(block_state, "gradient_blend_overlap", 0)
1618
- paste_core_into_canvas_blended(
1619
- block_state.canvas, block_state.weight_map, core, tile, overlap
1620
- )
1621
- elif blend_mode == "none":
1622
- paste_core_into_canvas(block_state.canvas, core, tile)
1623
- else:
1624
- raise ValueError(
1625
- f"Unsupported blend_mode '{blend_mode}'. "
1626
- "Supported modes: 'none', 'gradient'."
1627
- )
1628
-
1629
- return components, block_state
1630
-
1631
-
1632
- # ---------------------------------------------------------------------------
1633
- # Tile loop wrapper (LoopSequentialPipelineBlocks)
1634
- # ---------------------------------------------------------------------------
1635
-
1636
- class UltimateSDUpscaleTileLoopStep(LoopSequentialPipelineBlocks):
1637
- """Tile loop that iterates over the tile plan, running sub-blocks per tile.
1638
-
1639
- Supports:
1640
- - Two blending modes: ``"none"`` (core paste) and ``"gradient"`` (overlap blending)
1641
- - Optional seam-fix pass: re-denoises narrow bands along tile boundaries
1642
- with feathered mask blending
1643
-
1644
- Sub-blocks:
1645
- - ``UltimateSDUpscaleTilePrepareStep`` – crop, encode, prepare
1646
- - ``UltimateSDUpscaleTileDenoiserStep`` – denoising loop
1647
- - ``UltimateSDUpscaleTilePostProcessStep`` – decode + paste
1648
- """
1649
-
1650
- model_name = "stable-diffusion-xl"
1651
-
1652
- block_classes = [
1653
- UltimateSDUpscaleTilePrepareStep,
1654
- UltimateSDUpscaleTileDenoiserStep,
1655
- UltimateSDUpscaleTilePostProcessStep,
1656
- ]
1657
- block_names = ["tile_prepare", "tile_denoise", "tile_postprocess"]
1658
-
1659
- @property
1660
- def description(self) -> str:
1661
- return (
1662
- "Tile loop that iterates over the tile plan and runs sub-blocks per tile.\n"
1663
- "Supports 'none' and 'gradient' blending modes, plus optional seam-fix pass.\n"
1664
- "Sub-blocks:\n"
1665
- " - UltimateSDUpscaleTilePrepareStep: crop, VAE encode, set timesteps, "
1666
- "prepare latents, tile-aware add_cond\n"
1667
- " - UltimateSDUpscaleTileDenoiserStep: SDXL denoising loop\n"
1668
- " - UltimateSDUpscaleTilePostProcessStep: decode + paste core into canvas"
1669
- )
1670
-
1671
- @property
1672
- def loop_inputs(self) -> list[InputParam]:
1673
- return [
1674
- InputParam("tile_plan", type_hint=list, required=True,
1675
- description="List of TileSpec from the tile planning step."),
1676
- InputParam("upscaled_image", type_hint=PIL.Image.Image, required=True),
1677
- InputParam("upscaled_height", type_hint=int, required=True),
1678
- InputParam("upscaled_width", type_hint=int, required=True),
1679
- InputParam("tile_padding", type_hint=int, default=32),
1680
- InputParam("output_type", type_hint=str, default="pil"),
1681
- InputParam("blend_mode", type_hint=str, default="none",
1682
- description="Blending mode: 'none' (core paste) or 'gradient' (overlap blending)."),
1683
- InputParam("gradient_blend_overlap", type_hint=int, default=16,
1684
- description="Width of gradient ramp in pixels for 'gradient' blend mode."),
1685
- InputParam("seam_fix_plan", type_hint=list, default=[],
1686
- description="List of SeamFixSpec from tile planning. Empty disables seam fix."),
1687
- InputParam("seam_fix_mask_blur", type_hint=int, default=8,
1688
- description="Feathering width for seam-fix band blending."),
1689
- InputParam("seam_fix_strength", type_hint=float, default=0.3,
1690
- description="Denoise strength for seam-fix bands."),
1691
- InputParam("control_image",
1692
- description="Optional ControlNet conditioning image. If provided, tile denoising uses ControlNet."),
1693
- InputParam("control_guidance_start", default=0.0),
1694
- InputParam("control_guidance_end", default=1.0),
1695
- InputParam("controlnet_conditioning_scale", default=1.0),
1696
- InputParam("guess_mode", default=False),
1697
- InputParam("guidance_scale", type_hint=float, default=7.5,
1698
- description="Classifier-Free Guidance scale. Higher values produce images more aligned "
1699
- "with the prompt at the expense of lower image quality."),
1700
- ]
1701
-
1702
- @property
1703
- def loop_intermediate_outputs(self) -> list[OutputParam]:
1704
- return [
1705
- OutputParam("images", type_hint=list, description="Final stitched output images."),
1706
- ]
1707
-
1708
- def _run_seam_fix_band(self, components, block_state, band: SeamFixSpec, band_idx: int):
1709
- """Re-denoise one seam-fix band and blend it into the canvas."""
1710
- # Crop the band region directly from the float canvas to avoid
1711
- # full-canvas uint8 quantization per band (quality + perf).
1712
- crop_region = np.clip(
1713
- block_state.canvas[band.crop_y:band.crop_y + band.crop_h,
1714
- band.crop_x:band.crop_x + band.crop_w],
1715
- 0, 1,
1716
- )
1717
- crop_uint8 = (crop_region * 255).astype(np.uint8)
1718
- band_crop_pil = PIL.Image.fromarray(crop_uint8)
1719
-
1720
- # The PIL image is the crop region only, so the tile spec must use
1721
- # 0-based coordinates (the entire image IS the crop).
1722
- band_tile = TileSpec(
1723
- core_x=band.paste_x, core_y=band.paste_y,
1724
- core_w=band.band_w, core_h=band.band_h,
1725
- crop_x=0, crop_y=0,
1726
- crop_w=band.crop_w, crop_h=band.crop_h,
1727
- paste_x=band.paste_x, paste_y=band.paste_y,
1728
- )
1729
 
1730
- # Store original upscaled_image and swap in the band crop
1731
- original_image = block_state.upscaled_image
1732
- block_state.upscaled_image = band_crop_pil
1733
- original_control_image = getattr(block_state, "control_image_processed", None)
1734
- if getattr(block_state, "use_controlnet", False) and original_control_image is not None:
1735
- block_state.control_image_processed = original_control_image.crop(
1736
- (band.crop_x, band.crop_y, band.crop_x + band.crop_w, band.crop_y + band.crop_h)
1737
- )
 
 
 
 
 
1738
 
1739
- # Override strength for seam fix
1740
- original_strength = block_state.strength
1741
- block_state.strength = getattr(block_state, "seam_fix_strength", 0.3)
1742
 
1743
- # Run prepare + denoise (reuse existing sub-blocks)
1744
- prepare_block = self.sub_blocks["tile_prepare"]
1745
- denoise_block = self.sub_blocks["tile_denoise"]
1746
 
1747
- components, block_state = prepare_block(components, block_state, tile_idx=band_idx, tile=band_tile)
1748
- components, block_state = denoise_block(components, block_state, tile_idx=band_idx, tile=band_tile)
1749
 
1750
- # Decode the band
1751
- decode_state = _make_state({
1752
- "latents": block_state.latents,
1753
- "output_type": "np",
1754
- })
1755
- decode_block = self.sub_blocks["tile_postprocess"]._decode
1756
- components, decode_state = decode_block(components, decode_state)
1757
- decoded_np = decode_state.get("images")[0]
1758
 
1759
- if decoded_np.shape[0] != band.crop_h or decoded_np.shape[1] != band.crop_w:
1760
- pil_band = PIL.Image.fromarray((np.clip(decoded_np, 0, 1) * 255).astype(np.uint8))
1761
- pil_band = pil_band.resize((band.crop_w, band.crop_h), PIL.Image.LANCZOS)
1762
- decoded_np = np.array(pil_band).astype(np.float32) / 255.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1763
 
1764
- # Extract and paste band with feathered mask
1765
- band_pixels = extract_band_from_decoded(decoded_np, band)
1766
- seam_fix_mask_blur = getattr(block_state, "seam_fix_mask_blur", 8)
1767
- paste_seam_fix_band(block_state.canvas, band_pixels, band, seam_fix_mask_blur)
1768
 
1769
- # Restore original values
1770
- block_state.upscaled_image = original_image
1771
- if getattr(block_state, "use_controlnet", False):
1772
- block_state.control_image_processed = original_control_image
1773
- block_state.strength = original_strength
1774
 
1775
- return components, block_state
1776
 
1777
- @torch.no_grad()
1778
- def __call__(self, components, state: PipelineState) -> PipelineState:
1779
- block_state = self.get_block_state(state)
1780
 
1781
- tile_plan = block_state.tile_plan
1782
- h = block_state.upscaled_height
1783
- w = block_state.upscaled_width
1784
- output_type = block_state.output_type
1785
- blend_mode = getattr(block_state, "blend_mode", "none")
1786
- if blend_mode not in ("none", "gradient"):
1787
- raise ValueError(
1788
- f"Unsupported blend_mode '{blend_mode}'. Supported: 'none', 'gradient'."
1789
- )
1790
 
1791
- # --- Configure guidance_scale on guider ---
1792
- guidance_scale = getattr(block_state, "guidance_scale", 7.5)
1793
- components.guider.guidance_scale = guidance_scale
1794
 
1795
- control_image = getattr(block_state, "control_image", None)
1796
- block_state.use_controlnet = control_image is not None
1797
- if block_state.use_controlnet:
1798
- if isinstance(control_image, list):
 
 
 
 
 
1799
  raise ValueError(
1800
- "MultiDiffusion currently supports a single `control_image`, not a list."
1801
  )
1802
- if not hasattr(components, "controlnet") or components.controlnet is None:
 
 
 
 
 
 
 
 
1803
  raise ValueError(
1804
- "`control_image` was provided but `controlnet` component is missing. "
1805
- "Load a ControlNet model (for example, a tile model) into `pipe.controlnet`."
1806
  )
1807
- block_state.control_image_processed = _to_pil_rgb_image(control_image)
1808
- if block_state.control_image_processed.size != (w, h):
1809
- block_state.control_image_processed = block_state.control_image_processed.resize((w, h), PIL.Image.LANCZOS)
1810
- logger.info("ControlNet conditioning enabled for tiled denoising.")
1811
-
1812
- # Enable VAE tiling for memory-efficient encode/decode of large images.
1813
- # This lets the UNet process the full image (no tile seams) while the
1814
- # VAE handles memory via its own internal tiling.
1815
- if hasattr(components.vae, "enable_tiling"):
1816
- components.vae.enable_tiling()
 
 
 
 
 
 
 
 
 
 
 
1817
 
1818
- # Initialize canvas
1819
- block_state.canvas = np.zeros((h, w, 3), dtype=np.float32)
 
1820
 
1821
- # Prepare one global latent noise tensor and crop from it per tile.
1822
- # This keeps stochasticity consistent across tile boundaries.
1823
- vae_scale_factor = int(getattr(components, "vae_scale_factor", 8))
1824
- latent_h = max(1, h // vae_scale_factor)
1825
- latent_w = max(1, w // vae_scale_factor)
1826
- effective_batch = block_state.batch_size * block_state.num_images_per_prompt
1827
- block_state.global_noise_map = randn_tensor(
1828
- (effective_batch, 4, latent_h, latent_w),
1829
- generator=getattr(block_state, "generator", None),
1830
- device=components._execution_device,
1831
- dtype=block_state.dtype,
1832
- )
1833
- block_state.global_noise_scale = vae_scale_factor
1834
 
1835
- if blend_mode == "gradient":
1836
- block_state.weight_map = np.zeros((h, w), dtype=np.float32)
1837
- block_state.blend_mode = blend_mode
1838
- block_state.gradient_blend_overlap = getattr(block_state, "gradient_blend_overlap", 16)
1839
 
1840
- num_tiles = len(tile_plan)
1841
- seam_fix_plan = getattr(block_state, "seam_fix_plan", []) or []
1842
- total_steps = num_tiles + len(seam_fix_plan)
 
 
 
 
 
1843
 
1844
- logger.info(
1845
- f"Processing {num_tiles} tiles"
1846
- + (f" (blend_mode={blend_mode})" if blend_mode != "none" else "")
1847
- + (f" + {len(seam_fix_plan)} seam-fix bands" if seam_fix_plan else "")
1848
- )
1849
 
1850
- with self.progress_bar(total=total_steps) as progress_bar:
1851
- # Main tile loop
1852
- for i, tile in enumerate(tile_plan):
1853
- logger.debug(
1854
- f"Tile {i + 1}/{num_tiles}: core=({tile.core_x},{tile.core_y},{tile.core_w},{tile.core_h}) "
1855
- f"crop=({tile.crop_x},{tile.crop_y},{tile.crop_w},{tile.crop_h})"
1856
- )
1857
- components, block_state = self.loop_step(components, block_state, tile_idx=i, tile=tile)
1858
- progress_bar.update()
1859
-
1860
- # Finalize gradient blending before seam fix
1861
- if blend_mode == "gradient":
1862
- block_state.canvas = finalize_blended_canvas(block_state.canvas, block_state.weight_map)
1863
-
1864
- # Seam-fix pass
1865
- for j, band in enumerate(seam_fix_plan):
1866
- logger.debug(
1867
- f"Seam-fix {j + 1}/{len(seam_fix_plan)}: "
1868
- f"band=({band.band_x},{band.band_y},{band.band_w},{band.band_h}) "
1869
- f"{band.orientation}"
1870
- )
1871
- components, block_state = self._run_seam_fix_band(components, block_state, band, j)
1872
- progress_bar.update()
1873
 
1874
- # Finalize output
1875
- result = np.clip(block_state.canvas, 0.0, 1.0)
1876
- result_uint8 = (result * 255).astype(np.uint8)
 
1877
 
1878
- if output_type == "pil":
1879
- block_state.images = [PIL.Image.fromarray(result_uint8)]
1880
- elif output_type == "np":
1881
- block_state.images = [result]
1882
- elif output_type == "pt":
1883
- block_state.images = [torch.from_numpy(result).permute(2, 0, 1).unsqueeze(0)]
1884
- else:
1885
- block_state.images = [PIL.Image.fromarray(result_uint8)]
1886
 
1887
- self.set_block_state(state, block_state)
1888
- return components, state
 
1889
 
 
1890
 
1891
- # =============================================================================
1892
- # MultiDiffusion: latent-space noise prediction blending
1893
- # =============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1894
 
1895
 
1896
- def _make_cosine_tile_weight(
1897
- h: int, w: int, overlap: int, device, dtype,
1898
- is_top: bool = False, is_bottom: bool = False,
1899
- is_left: bool = False, is_right: bool = False,
1900
- ) -> torch.Tensor:
1901
- """Create a boundary-aware 2D cosine-ramp weight for MultiDiffusion blending.
1902
-
1903
- Weight is 1.0 in the center and smoothly fades at edges that overlap with
1904
- neighboring tiles. Edges that touch the image boundary keep weight=1.0 to
1905
- prevent noise amplification from dividing by near-zero weights.
1906
-
1907
- Args:
1908
- h: Tile height in latent pixels.
1909
- w: Tile width in latent pixels.
1910
- overlap: Overlap in latent pixels.
1911
- device: Torch device.
1912
- dtype: Torch dtype.
1913
- is_top: True if this tile touches the top image boundary.
1914
- is_bottom: True if this tile touches the bottom image boundary.
1915
- is_left: True if this tile touches the left image boundary.
1916
- is_right: True if this tile touches the right image boundary.
1917
-
1918
- Returns:
1919
- Tensor of shape ``(1, 1, h, w)`` for broadcasting.
1920
- """
1921
- def _ramp(length, overlap_size, keep_start, keep_end):
1922
- ramp = torch.ones(length, device=device, dtype=dtype)
1923
- if overlap_size > 0 and length > 2 * overlap_size:
1924
- fade = 0.5 * (1.0 - torch.cos(torch.linspace(0, math.pi, overlap_size, device=device, dtype=dtype)))
1925
- if not keep_start:
1926
- ramp[:overlap_size] = fade
1927
- if not keep_end:
1928
- ramp[-overlap_size:] = fade.flip(0)
1929
- return ramp
1930
 
1931
- w_h = _ramp(h, overlap, keep_start=is_top, keep_end=is_bottom)
1932
- w_w = _ramp(w, overlap, keep_start=is_left, keep_end=is_right)
1933
- return (w_h[:, None] * w_w[None, :]).unsqueeze(0).unsqueeze(0)
1934
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1935
 
1936
  class UltimateSDUpscaleMultiDiffusionStep(ModularPipelineBlocks):
1937
  """Single block that encodes, denoises with MultiDiffusion, and decodes.
@@ -2580,42 +1102,7 @@ class UltimateSDUpscaleMultiDiffusionStep(ModularPipelineBlocks):
2580
  # modular_blocks
2581
  # ============================================================
2582
 
2583
- # Copyright 2025 The HuggingFace Team. All rights reserved.
2584
- #
2585
- # Licensed under the Apache License, Version 2.0 (the "License");
2586
- # you may not use this file except in compliance with the License.
2587
- # You may obtain a copy of the License at
2588
- #
2589
- # http://www.apache.org/licenses/LICENSE-2.0
2590
- #
2591
- # Unless required by applicable law or agreed to in writing, software
2592
- # distributed under the License is distributed on an "AS IS" BASIS,
2593
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2594
- # See the License for the specific language governing permissions and
2595
- # limitations under the License.
2596
-
2597
- """Top-level block composition for Modular SDXL Upscale.
2598
-
2599
- The pipeline preserves the standard SDXL block graph as closely as
2600
- possible, inserting upscale and tile-plan steps and wrapping the per-tile
2601
- work in a ``LoopSequentialPipelineBlocks``::
2602
-
2603
- text_encoder → upscale → tile_plan → input → set_timesteps → tiled_img2img
2604
-
2605
- Inside ``tiled_img2img`` (tile loop), each tile runs:
2606
-
2607
- tile_prepare → tile_denoise → tile_postprocess
2608
-
2609
- Followed by an optional seam-fix pass that re-denoises narrow bands along
2610
- tile boundaries with feathered mask blending.
2611
-
2612
- Features:
2613
- - Linear and chess (checkerboard) tile traversal
2614
- - Non-overlapping core paste or gradient overlap blending
2615
- - Optional seam-fix band re-denoise with configurable width and mask blur
2616
- - Optional ControlNet tile conditioning for stronger cross-tile structure consistency
2617
- - Tile-aware SDXL micro-conditioning (crops_coords_top_left per tile)
2618
- """
2619
 
2620
  from diffusers.utils import logging
2621
  from diffusers.modular_pipelines.modular_pipeline import SequentialPipelineBlocks
@@ -2629,85 +1116,22 @@ from diffusers.modular_pipelines.stable_diffusion_xl.before_denoise import (
2629
  logger = logging.get_logger(__name__)
2630
 
2631
 
2632
- class UltimateSDUpscaleBlocks(SequentialPipelineBlocks):
2633
- """Modular pipeline blocks for tiled SDXL upscaling.
2634
-
2635
- Block graph::
2636
-
2637
- [0] text_encoder – StableDiffusionXLTextEncoderStep (reused)
2638
- [1] upscale – UltimateSDUpscaleUpscaleStep (new)
2639
- [2] tile_plan – UltimateSDUpscaleTilePlanStep (new)
2640
- [3] input – StableDiffusionXLInputStep (reused)
2641
- [4] set_timesteps – StableDiffusionXLImg2ImgSetTimestepsStep (reused)
2642
- [5] tiled_img2img – UltimateSDUpscaleTileLoopStep (tile loop + seam fix)
2643
-
2644
- Features:
2645
- - Linear and chess (checkerboard) tile traversal
2646
- - Non-overlapping core paste or gradient overlap blending
2647
- - Seam-fix band re-denoise with feathered mask blending
2648
- - Tile-aware SDXL conditioning (crops_coords_top_left per tile)
2649
- """
2650
-
2651
- block_classes = [
2652
- UltimateSDUpscaleTextEncoderStep,
2653
- UltimateSDUpscaleUpscaleStep,
2654
- UltimateSDUpscaleTilePlanStep,
2655
- StableDiffusionXLInputStep,
2656
- StableDiffusionXLImg2ImgSetTimestepsStep,
2657
- UltimateSDUpscaleTileLoopStep,
2658
- ]
2659
- block_names = [
2660
- "text_encoder",
2661
- "upscale",
2662
- "tile_plan",
2663
- "input",
2664
- "set_timesteps",
2665
- "tiled_img2img",
2666
- ]
2667
-
2668
- _workflow_map = {
2669
- "upscale": {"image": True, "prompt": True},
2670
- "upscale_controlnet": {"image": True, "control_image": True, "prompt": True},
2671
- }
2672
-
2673
- @property
2674
- def description(self):
2675
- return (
2676
- "Modular tiled upscaling pipeline for Stable Diffusion XL.\n"
2677
- "Upscales an input image and refines it using tiled denoising.\n"
2678
- "Default: single-pass mode (tile_size=2048) — seamless, no tile artifacts.\n"
2679
- "For very large images: set tile_size=512 for tiled mode with optional "
2680
- "chess traversal, gradient blending, seam-fix, and ControlNet tile conditioning."
2681
- )
2682
-
2683
- @property
2684
- def outputs(self):
2685
- return [OutputParam.template("images")]
2686
-
2687
-
2688
  class MultiDiffusionUpscaleBlocks(SequentialPipelineBlocks):
2689
  """Modular pipeline blocks for tiled SDXL upscaling with MultiDiffusion.
2690
 
2691
  Uses latent-space noise prediction blending across overlapping tiles for
2692
- **seamless** tiled upscaling at any resolution. This is the recommended
2693
- block set for high-quality upscaling.
2694
 
2695
  Block graph::
2696
 
2697
- [0] text_encoder StableDiffusionXLTextEncoderStep (reused)
2698
- [1] upscale UltimateSDUpscaleUpscaleStep (Lanczos resize)
2699
- [2] input StableDiffusionXLInputStep (reused)
2700
- [3] set_timesteps StableDiffusionXLImg2ImgSetTimestepsStep (reused)
2701
- [4] multidiffusion UltimateSDUpscaleMultiDiffusionStep (NEW)
2702
 
2703
  The MultiDiffusion step handles VAE encode, tiled denoise with blending,
2704
  and VAE decode internally, using VAE tiling for memory efficiency.
2705
-
2706
- Features:
2707
- - Seamless output at any resolution (no tile boundary artifacts)
2708
- - Optional ControlNet Tile conditioning
2709
- - Configurable latent tile size and overlap
2710
- - Single-pass for small images, tiled for large images
2711
  """
2712
 
2713
  block_classes = [
 
8
  # utils_tiling
9
  # ============================================================
10
 
11
+ """Tile planning and cosine blending weights for MultiDiffusion."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ from dataclasses import dataclass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ import torch
 
 
 
 
 
 
16
 
17
 
18
  @dataclass
19
  class LatentTileSpec:
20
+ """Tile specification in latent space.
21
 
22
  Attributes:
23
  y: Top edge in latent pixels.
 
32
  w: int
33
 
34
 
35
+ def validate_tile_params(tile_size: int, overlap: int) -> None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  if tile_size <= 0:
37
  raise ValueError(f"`tile_size` must be positive, got {tile_size}.")
38
  if overlap < 0:
 
43
  f"Got overlap={overlap}, tile_size={tile_size}."
44
  )
45
 
46
+
47
+ def plan_latent_tiles(
48
+ latent_h: int,
49
+ latent_w: int,
50
+ tile_size: int = 64,
51
+ overlap: int = 8,
52
+ ) -> list[LatentTileSpec]:
53
+ """Plan overlapping tiles in latent space for MultiDiffusion.
54
+
55
+ Tiles overlap by ``overlap`` latent pixels. Edge tiles are clamped to
56
+ the latent bounds.
57
+ """
58
+ validate_tile_params(tile_size, overlap)
59
+
60
  stride = tile_size - overlap
61
  tiles: list[LatentTileSpec] = []
62
 
63
  y = 0
64
  while y < latent_h:
65
  h = min(tile_size, latent_h - y)
 
66
  if h < tile_size and y > 0:
67
  y = max(0, latent_h - tile_size)
68
  h = latent_h - y
 
87
  return tiles
88
 
89
 
90
+ def make_cosine_tile_weight(
91
+ h: int,
92
+ w: int,
93
+ overlap: int,
94
+ device: torch.device,
95
+ dtype: torch.dtype,
96
+ is_top: bool = False,
97
+ is_bottom: bool = False,
98
+ is_left: bool = False,
99
+ is_right: bool = False,
100
+ ) -> torch.Tensor:
101
+ """Boundary-aware cosine blending weight for one tile.
102
+
103
+ Returns shape (1, 1, h, w). Canvas-edge sides get weight 1.0 (no fade),
104
+ interior overlap regions get a half-cosine ramp from 0 to 1.
105
+ """
106
+ import math
107
+
108
+ wy = torch.ones(h, device=device, dtype=dtype)
109
+ wx = torch.ones(w, device=device, dtype=dtype)
110
+
111
+ ramp = min(overlap, h // 2, w // 2)
112
+ if ramp <= 0:
113
+ return torch.ones(1, 1, h, w, device=device, dtype=dtype)
114
+
115
+ cos_ramp = torch.tensor(
116
+ [0.5 * (1 - math.cos(math.pi * i / ramp)) for i in range(ramp)],
117
+ device=device,
118
+ dtype=dtype,
119
+ )
120
+
121
+ if not is_top:
122
+ wy[:ramp] = cos_ramp
123
+ if not is_bottom:
124
+ wy[-ramp:] = cos_ramp.flip(0)
125
+ if not is_left:
126
+ wx[:ramp] = cos_ramp
127
+ if not is_right:
128
+ wx[-ramp:] = cos_ramp.flip(0)
129
+
130
+ weight = wy[:, None] * wx[None, :]
131
+ return weight.unsqueeze(0).unsqueeze(0)
132
+
133
+
134
  # ============================================================
135
  # input
136
  # ============================================================
 
149
  # See the License for the specific language governing permissions and
150
  # limitations under the License.
151
 
152
+ """Input steps for Modular SDXL Upscale: text encoding, Lanczos upscale."""
153
+
154
  import PIL.Image
155
  import torch
156
 
 
166
  class UltimateSDUpscaleTextEncoderStep(StableDiffusionXLTextEncoderStep):
167
  """SDXL text encoder step that applies guidance scale before encoding.
168
 
169
+ Syncs the guider's guidance_scale before prompt encoding so that
170
+ unconditional embeddings are always produced when CFG is active.
 
171
 
172
+ Also applies a default negative prompt for upscaling when the user
173
+ does not provide one.
 
 
 
 
 
174
  """
175
 
176
  DEFAULT_NEGATIVE_PROMPT = "blurry, low quality, artifacts, noise, jpeg compression"
177
 
178
  @property
179
  def inputs(self) -> list[InputParam]:
 
180
  return super().inputs + [
181
  InputParam(
182
  "guidance_scale",
183
  type_hint=float,
184
  default=7.5,
185
+ description="Classifier-Free Guidance scale.",
 
 
 
186
  ),
187
  InputParam(
188
  "use_default_negative",
189
  type_hint=bool,
190
  default=True,
191
+ description="Apply default negative prompt when none is provided.",
 
 
 
 
192
  ),
193
  ]
194
 
 
200
  if hasattr(components, "guider") and components.guider is not None:
201
  components.guider.guidance_scale = guidance_scale
202
 
 
203
  use_default_negative = getattr(block_state, "use_default_negative", True)
204
  if use_default_negative:
205
  neg = getattr(block_state, "negative_prompt", None)
 
211
 
212
 
213
  class UltimateSDUpscaleUpscaleStep(ModularPipelineBlocks):
214
+ """Upscales the input image using Lanczos interpolation."""
 
 
 
 
 
215
 
216
  @property
217
  def description(self) -> str:
218
+ return "Upscale input image using Lanczos interpolation."
 
 
 
 
219
 
220
  @property
221
  def inputs(self) -> list[InputParam]:
222
  return [
223
+ InputParam("image", type_hint=PIL.Image.Image, required=True,
224
+ description="Input image to upscale."),
225
+ InputParam("upscale_factor", type_hint=float, default=2.0,
226
+ description="Scale multiplier."),
 
 
 
 
 
 
 
 
227
  ]
228
 
229
  @property
230
  def intermediate_outputs(self) -> list[OutputParam]:
231
  return [
232
+ OutputParam("upscaled_image", type_hint=PIL.Image.Image),
233
+ OutputParam("upscaled_width", type_hint=int),
234
+ OutputParam("upscaled_height", type_hint=int),
 
 
 
 
 
 
 
 
 
 
 
 
235
  ]
236
 
237
  @torch.no_grad()
 
242
  upscale_factor = block_state.upscale_factor
243
 
244
  if not isinstance(image, PIL.Image.Image):
245
+ raise ValueError(f"Expected PIL.Image, got {type(image)}.")
 
 
 
246
 
247
  new_width = int(image.width * upscale_factor)
248
  new_height = int(image.height * upscale_factor)
 
251
  block_state.upscaled_width = new_width
252
  block_state.upscaled_height = new_height
253
 
254
+ logger.info(f"Upscaled {image.width}x{image.height} -> {new_width}x{new_height}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
  self.set_block_state(state, block_state)
257
  return components, state
258
 
259
 
260
+ # ============================================================
261
+ # denoise
262
+ # ============================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
 
264
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
265
+ #
266
+ # Licensed under the Apache License, Version 2.0 (the "License");
267
+ # you may not use this file except in compliance with the License.
268
+ # You may obtain a copy of the License at
269
+ #
270
+ # http://www.apache.org/licenses/LICENSE-2.0
271
+ #
272
+ # Unless required by applicable law or agreed to in writing, software
273
+ # distributed under the License is distributed on an "AS IS" BASIS,
274
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
275
+ # See the License for the specific language governing permissions and
276
+ # limitations under the License.
277
 
278
+ """MultiDiffusion tiled upscaling step for Modular SDXL Upscale.
 
 
279
 
280
+ Blends noise predictions from overlapping latent tiles using cosine weights.
281
+ Reuses SDXL blocks via their public interface.
282
+ """
283
 
284
+ import math
285
+ import time
286
 
287
+ import numpy as np
288
+ import PIL.Image
289
+ import torch
290
+ from tqdm.auto import tqdm
 
 
 
 
291
 
292
+ from diffusers.configuration_utils import FrozenDict
293
+ from diffusers.guiders import ClassifierFreeGuidance
294
+ from diffusers.image_processor import VaeImageProcessor
295
+ from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
296
+ from diffusers.schedulers import DPMSolverMultistepScheduler, EulerDiscreteScheduler
297
+ from diffusers.utils import logging
298
+ from diffusers.utils.torch_utils import randn_tensor
299
+ from diffusers.modular_pipelines.modular_pipeline import (
300
+ ModularPipelineBlocks,
301
+ PipelineState,
302
+ )
303
+ from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
304
+ from diffusers.modular_pipelines.stable_diffusion_xl.before_denoise import (
305
+ StableDiffusionXLControlNetInputStep,
306
+ StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
307
+ StableDiffusionXLImg2ImgPrepareLatentsStep,
308
+ prepare_latents_img2img,
309
+ )
310
+ from diffusers.modular_pipelines.stable_diffusion_xl.decoders import StableDiffusionXLDecodeStep
311
+ from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLVaeEncoderStep
312
 
 
 
 
 
313
 
314
+ logger = logging.get_logger(__name__)
 
 
 
 
315
 
 
316
 
317
+ # ---------------------------------------------------------------------------
318
+ # Helper: populate a PipelineState from a dict
319
+ # ---------------------------------------------------------------------------
320
 
321
+ def _make_state(values: dict, kwargs_type_map: dict | None = None) -> PipelineState:
322
+ """Create a PipelineState and set values, optionally with kwargs_type."""
323
+ state = PipelineState()
324
+ kwargs_type_map = kwargs_type_map or {}
325
+ for k, v in values.items():
326
+ state.set(k, v, kwargs_type_map.get(k))
327
+ return state
 
 
328
 
 
 
 
329
 
330
+ def _to_pil_rgb_image(image) -> PIL.Image.Image:
331
+ """Convert a tensor/ndarray/PIL image to a RGB PIL image."""
332
+ if isinstance(image, PIL.Image.Image):
333
+ return image.convert("RGB")
334
+
335
+ if torch.is_tensor(image):
336
+ tensor = image.detach().cpu()
337
+ if tensor.ndim == 4:
338
+ if tensor.shape[0] != 1:
339
  raise ValueError(
340
+ f"`control_image` tensor batch must be 1 for tiled upscaling, got shape {tuple(tensor.shape)}."
341
  )
342
+ tensor = tensor[0]
343
+ if tensor.ndim == 3 and tensor.shape[0] in (1, 3, 4) and tensor.shape[-1] not in (1, 3, 4):
344
+ tensor = tensor.permute(1, 2, 0)
345
+ image = tensor.numpy()
346
+
347
+ if isinstance(image, np.ndarray):
348
+ array = image
349
+ if array.ndim == 4:
350
+ if array.shape[0] != 1:
351
  raise ValueError(
352
+ f"`control_image` ndarray batch must be 1 for tiled upscaling, got shape {array.shape}."
 
353
  )
354
+ array = array[0]
355
+ if array.ndim == 3 and array.shape[0] in (1, 3, 4) and array.shape[-1] not in (1, 3, 4):
356
+ array = np.transpose(array, (1, 2, 0))
357
+ if array.ndim == 2:
358
+ array = np.stack([array] * 3, axis=-1)
359
+ if array.ndim != 3:
360
+ raise ValueError(f"`control_image` must have 2 or 3 dimensions, got shape {array.shape}.")
361
+ if array.shape[-1] == 1:
362
+ array = np.repeat(array, 3, axis=-1)
363
+ if array.shape[-1] == 4:
364
+ array = array[..., :3]
365
+ if array.shape[-1] != 3:
366
+ raise ValueError(f"`control_image` channel dimension must be 1/3/4, got shape {array.shape}.")
367
+ if array.dtype != np.uint8:
368
+ array = np.asarray(array, dtype=np.float32)
369
+ max_val = float(np.max(array)) if array.size > 0 else 1.0
370
+ if max_val <= 1.0:
371
+ array = (np.clip(array, 0.0, 1.0) * 255.0).astype(np.uint8)
372
+ else:
373
+ array = np.clip(array, 0.0, 255.0).astype(np.uint8)
374
+ return PIL.Image.fromarray(array).convert("RGB")
375
 
376
+ raise ValueError(
377
+ f"Unsupported `control_image` type {type(image)}. Expected PIL.Image, torch.Tensor, or numpy.ndarray."
378
+ )
379
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
 
381
+ # ---------------------------------------------------------------------------
382
+ # Scheduler swap helper (Feature 5)
383
+ # ---------------------------------------------------------------------------
 
384
 
385
+ _SCHEDULER_ALIASES = {
386
+ "euler": "EulerDiscreteScheduler",
387
+ "euler discrete": "EulerDiscreteScheduler",
388
+ "eulerdiscretescheduler": "EulerDiscreteScheduler",
389
+ "dpm++ 2m": "DPMSolverMultistepScheduler",
390
+ "dpmsolvermultistepscheduler": "DPMSolverMultistepScheduler",
391
+ "dpm++ 2m karras": "DPMSolverMultistepScheduler+karras",
392
+ }
393
 
 
 
 
 
 
394
 
395
+ def _swap_scheduler(components, scheduler_name: str):
396
+ """Swap the scheduler on ``components`` given a human-readable name.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
397
 
398
+ Supported names (case-insensitive):
399
+ - ``"Euler"`` / ``"EulerDiscreteScheduler"``
400
+ - ``"DPM++ 2M"`` / ``"DPMSolverMultistepScheduler"``
401
+ - ``"DPM++ 2M Karras"`` (DPMSolverMultistep with Karras sigmas)
402
 
403
+ If the requested scheduler is already active, this is a no-op.
404
+ """
405
+ key = scheduler_name.strip().lower()
406
+ resolved = _SCHEDULER_ALIASES.get(key, key)
 
 
 
 
407
 
408
+ use_karras = resolved.endswith("+karras")
409
+ if use_karras:
410
+ resolved = resolved.replace("+karras", "")
411
 
412
+ current = type(components.scheduler).__name__
413
 
414
+ if resolved == "EulerDiscreteScheduler":
415
+ if current != "EulerDiscreteScheduler":
416
+ components.scheduler = EulerDiscreteScheduler.from_config(components.scheduler.config)
417
+ logger.info("Swapped scheduler to EulerDiscreteScheduler")
418
+ elif resolved == "DPMSolverMultistepScheduler":
419
+ if current != "DPMSolverMultistepScheduler" or (
420
+ use_karras and not getattr(components.scheduler.config, "use_karras_sigmas", False)
421
+ ):
422
+ extra_kwargs = {}
423
+ if use_karras:
424
+ extra_kwargs["use_karras_sigmas"] = True
425
+ components.scheduler = DPMSolverMultistepScheduler.from_config(
426
+ components.scheduler.config, **extra_kwargs
427
+ )
428
+ logger.info(f"Swapped scheduler to DPMSolverMultistepScheduler (karras={use_karras})")
429
+ else:
430
+ logger.warning(
431
+ f"Unknown scheduler_name '{scheduler_name}'. Keeping current scheduler "
432
+ f"({current}). Supported: 'Euler', 'DPM++ 2M', 'DPM++ 2M Karras'."
433
+ )
434
 
435
 
436
+ # ---------------------------------------------------------------------------
437
+ # Auto-strength helper (Feature 2)
438
+ # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
 
440
+ def _compute_auto_strength(upscale_factor: float, pass_index: int, num_passes: int) -> float:
441
+ """Return the auto-scaled denoise strength for a given pass.
 
442
 
443
+ Rules:
444
+ - Single-pass 2x: 0.3
445
+ - Single-pass 4x: 0.15
446
+ - Progressive passes: first pass=0.3, subsequent passes=0.2
447
+ """
448
+ if num_passes > 1:
449
+ return 0.3 if pass_index == 0 else 0.2
450
+ # Single pass
451
+ if upscale_factor <= 2.0:
452
+ return 0.3
453
+ elif upscale_factor <= 4.0:
454
+ return 0.15
455
+ else:
456
+ return 0.1
457
 
458
  class UltimateSDUpscaleMultiDiffusionStep(ModularPipelineBlocks):
459
  """Single block that encodes, denoises with MultiDiffusion, and decodes.
 
1102
  # modular_blocks
1103
  # ============================================================
1104
 
1105
+ """Block composition for Modular SDXL Upscale."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1106
 
1107
  from diffusers.utils import logging
1108
  from diffusers.modular_pipelines.modular_pipeline import SequentialPipelineBlocks
 
1116
  logger = logging.get_logger(__name__)
1117
 
1118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1119
  class MultiDiffusionUpscaleBlocks(SequentialPipelineBlocks):
1120
  """Modular pipeline blocks for tiled SDXL upscaling with MultiDiffusion.
1121
 
1122
  Uses latent-space noise prediction blending across overlapping tiles for
1123
+ seamless tiled upscaling at any resolution.
 
1124
 
1125
  Block graph::
1126
 
1127
+ [0] text_encoder - SDXL TextEncoderStep (reused)
1128
+ [1] upscale - Lanczos resize
1129
+ [2] input - SDXL InputStep (reused)
1130
+ [3] set_timesteps - SDXL Img2Img SetTimestepsStep (reused)
1131
+ [4] multidiffusion - MultiDiffusion step
1132
 
1133
  The MultiDiffusion step handles VAE encode, tiled denoise with blending,
1134
  and VAE decode internally, using VAE tiling for memory efficiency.
 
 
 
 
 
 
1135
  """
1136
 
1137
  block_classes = [