# Eji-Sensei14's picture
# Upload folder using huggingface_hub
# c6535db verified
"""
Refactored USD Upscaler batch processing patch.
Preserves original behavior but:
- Organizes imports and helpers
- Replaces prints with logging
- Factors duplicated logic (tile preparation, batching, decoding)
- Uses functools.wraps when monkey-patching methods
- Adds type hints and docstrings for clarity
"""
from __future__ import annotations
from functools import wraps
import logging
import math
from typing import Tuple, List
from PIL import Image
import modules.shared as shared
from modules.processing import process_batch_tiles
from repositories import ultimate_upscale as usdu
logger = logging.getLogger(__name__)
# Attach a console handler only once. Re-importing or reloading this module
# returns the same named logger, so an unconditional addHandler() would stack
# handlers and print every message multiple times.
if not logger.handlers:
    logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.INFO)
# Compatibility for older Pillow versions (< 9.1) that have no
# ``Image.Resampling`` enum: alias the module itself so that
# ``Image.Resampling.LANCZOS`` etc. resolve to the legacy module-level
# constants. Only a missing attribute is expected here, so probe with
# hasattr() instead of swallowing every exception.
if not hasattr(Image, "Resampling"):
    Image.Resampling = Image  # type: ignore
# -------------------------
# Utility helpers
# -------------------------
def round_length(length: float, multiple: int = 8) -> int:
    """Round *length* to the nearest multiple of *multiple* (default 8).

    ``length`` may be a float (callers pass e.g. ``image.width * upscale_by``);
    the result is always an ``int``.

    Ties use Python's built-in ``round()`` semantics (round half to even), so
    a length exactly halfway between two multiples goes to the one whose
    quotient is even: ``round_length(20) == 16`` and ``round_length(28) == 32``.
    """
    return round(length / multiple) * multiple
# -------------------------
# Monkey patches for USDUpscaler sizing / redraw / seams fix
# -------------------------
def patch_usdu_upscaler_init():
    """Wrap ``USDUpscaler.__init__`` so the target size is snapped to multiples.

    Before the original constructor runs, ``p.width`` and ``p.height`` are set
    to the source image dimensions scaled by ``p.upscale_by`` and passed
    through :func:`round_length`.
    """
    original_ctor = usdu.USDUpscaler.__init__

    @wraps(original_ctor)
    def rounded_ctor(self, p, image, upscaler_index, save_redraw, save_seams_fix, tile_width, tile_height):
        scaled_w = image.width * p.upscale_by
        p.width = round_length(scaled_w)
        scaled_h = image.height * p.upscale_by
        p.height = round_length(scaled_h)
        return original_ctor(
            self, p, image, upscaler_index, save_redraw, save_seams_fix, tile_width, tile_height
        )

    usdu.USDUpscaler.__init__ = rounded_ctor
def patch_usdu_redraw_init():
    """Wrap ``USDURedraw.init_draw`` so the redraw tile size is snapped to multiples.

    The original ``init_draw`` runs first. Then ``p.width``/``p.height`` are
    overwritten with the padded tile size (``tile_width + padding`` /
    ``tile_height + padding``) rounded by :func:`round_length`. The
    ``(mask, draw)`` pair from the original is returned.
    """
    original_init_draw = usdu.USDURedraw.init_draw

    @wraps(original_init_draw)
    def rounded_init_draw(self, p, width, height):
        mask_img, drawer = original_init_draw(self, p, width, height)
        padded_w = self.tile_width + self.padding
        padded_h = self.tile_height + self.padding
        p.width, p.height = round_length(padded_w), round_length(padded_h)
        return mask_img, drawer

    usdu.USDURedraw.init_draw = rounded_init_draw
def patch_usdu_seams_fix_init():
    """Patch USDUSeamsFix.init_draw to round the tile size used for seams fixing.

    The original ``init_draw`` runs first. Then ``p.width``/``p.height`` are
    overwritten with the padded tile size (``tile_width + padding`` /
    ``tile_height + padding``) rounded by :func:`round_length`.
    """
    old_init_draw = usdu.USDUSeamsFix.init_draw

    @wraps(old_init_draw)
    def new_init_draw(self, p):
        result = old_init_draw(self, p)
        p.width = round_length(self.tile_width + self.padding)
        p.height = round_length(self.tile_height + self.padding)
        # Hand back whatever the wrapped method returned (as the redraw
        # patch does) instead of silently replacing it with None.
        return result

    usdu.USDUSeamsFix.init_draw = new_init_draw
def patch_usdu_upscale_method():
    """Patch USDUpscaler.upscale to keep shared.batch resized to p.width/p.height.

    After the original ``upscale`` runs, ``shared.batch`` is rebuilt:
      - entry 0 becomes the upscaler's ``self.image``;
      - every other entry is resized to ``(self.p.width, self.p.height)`` with
        Lanczos resampling.
    Subsequent batched tile processing reads ``shared.batch``.
    """
    old_upscale = usdu.USDUpscaler.upscale

    @wraps(old_upscale)
    def new_upscale(self):
        result = old_upscale(self)
        target_size = (self.p.width, self.p.height)
        # Use the Resampling enum; the module-level Pillow shim aliases it on
        # old Pillow versions, so this works across versions.
        shared.batch = [self.image] + [
            img.resize(target_size, resample=Image.Resampling.LANCZOS)
            for img in shared.batch[1:]
        ]
        return result

    usdu.USDUpscaler.upscale = new_upscale
# Apply the sizing / upscale patches at import time, in this order.
for _install_patch in (
    patch_usdu_upscaler_init,
    patch_usdu_redraw_init,
    patch_usdu_seams_fix_init,
    patch_usdu_upscale_method,
):
    _install_patch()
# Do not leave the loop variable behind as a module attribute.
del _install_patch
# -------------------------
# Patched script.run replacement
# -------------------------
def patched_script_run(self, p, _, tile_width, tile_height, mask_blur, padding, seams_fix_width, seams_fix_denoise, seams_fix_padding,
                       upscaler_index, save_upscaled_image, redraw_mode, save_seams_fix_image, seams_fix_mask_blur,
                       seams_fix_type, target_size_type, custom_width, custom_height, custom_scale):
    """
    Replacement for usdu.Script.run that preserves the original batch_size
    and delegates to the (patched) USDUpscaler and redraw pipeline.

    Steps:
      1. Fix the seed, run ``torch_gc`` and configure ``p``: disable grid and
         sample saving, ``inpaint_full_res=False``, ``inpainting_fill=1``,
         ``n_iter=1``. The caller's ``batch_size`` is kept.
      2. Flatten the first init image onto the img2img background colour.
      3. Optionally override the output size:
         ``target_size_type == 1`` uses ``custom_width``/``custom_height``;
         ``target_size_type == 2`` scales the init image by ``custom_scale``
         and rounds each side up to a multiple of 64.
      4. Build a ``usdu.USDUpscaler``, upscale, set up redraw and seams fix,
         then process.

    ``_`` is an unused positional argument kept for signature compatibility.
    ``p`` is mutated in place.

    Returns:
        ``usdu.processing.Processed`` holding the upscaler's result images.
        If there is no init image, returns an empty result whose info text is
        "Empty image".
    """
    preserved_batch_size = getattr(p, 'batch_size', 1)
    logger.info("[USDU Batch Debug] Patched script.run() preserving batch_size=%s", preserved_batch_size)

    # Init (matching original code)
    usdu.processing.fix_seed(p)
    usdu.devices.torch_gc()

    # Keep original file-saving flags as in original code
    p.do_not_save_grid = True
    p.do_not_save_samples = True
    p.inpaint_full_res = False
    p.inpainting_fill = 1
    p.n_iter = 1
    # Re-assert the caller's batch size; keeping it is the purpose of this patch.
    p.batch_size = preserved_batch_size

    seed = p.seed

    # Init image. An empty init_images list is treated the same as a None
    # entry: indexing it directly would raise IndexError instead of taking
    # the "Empty image" early return.
    init_img = p.init_images[0] if p.init_images else None
    if init_img is None:
        return usdu.processing.Processed(p, [], seed, "Empty image")
    init_img = usdu.images.flatten(init_img, usdu.shared.opts.img2img_background_color)

    # Override size by user choice
    if target_size_type == 1:
        p.width = custom_width
        p.height = custom_height
    elif target_size_type == 2:
        p.width = math.ceil((init_img.width * custom_scale) / 64) * 64
        p.height = math.ceil((init_img.height * custom_scale) / 64) * 64

    # Create and run upscaler
    upscaler = usdu.USDUpscaler(p, init_img, upscaler_index, save_upscaled_image, save_seams_fix_image, tile_width, tile_height)
    upscaler.upscale()

    # Drawing & seams fix setup
    upscaler.setup_redraw(redraw_mode, padding, mask_blur)
    upscaler.setup_seams_fix(seams_fix_padding, seams_fix_denoise, seams_fix_mask_blur, seams_fix_width, seams_fix_type)
    upscaler.print_info()
    upscaler.add_extra_info()
    upscaler.process()

    result_images = upscaler.result_images
    logger.info("[USDU Batch Debug] Patched script.run() complete, batch_size=%s", p.batch_size)
    return usdu.processing.Processed(p, result_images, seed, upscaler.initial_info or "")
# Install the batch-size-preserving run() on the USDU Script class.
setattr(usdu.Script, "run", patched_script_run)
# -------------------------
# Replace USDURedraw.linear_process and chess_process with batched variants
# -------------------------
def _process_one_tile_batch(p, batch, batch_number, calc_rectangle) -> None:
    """Log one group of tiles, process it, and store the result in shared.batch."""
    logger.info("[USDU Batch Debug] Processing batch #%s with %s tiles: %s", batch_number, len(batch), batch)
    shared.batch = process_batch_tiles(p, batch, shared.batch, calc_rectangle)


def _process_tiles_in_batches(p, tiles, batch_size, calc_rectangle) -> int:
    """Feed ``tiles`` to ``process_batch_tiles`` in groups of ``batch_size``.

    Args:
        p: Processing object forwarded to ``process_batch_tiles``.
        tiles: Iterable of ``(x_index, y_index)`` tile coordinates, consumed
            in order.
        batch_size: Maximum number of tiles per ``process_batch_tiles`` call.
        calc_rectangle: Callable forwarded to ``process_batch_tiles``.

    Returns:
        The number of batches processed.

    A final, smaller group is flushed after the last tile. If
    ``shared.state.interrupted`` is set, iteration stops immediately and
    tiles still waiting for a full batch are dropped, so an interrupt never
    starts another generation.
    """
    batch_count = 0
    pending: List[Tuple[int, int]] = []
    for tile in tiles:
        if shared.state.interrupted:
            return batch_count
        pending.append(tile)
        if len(pending) >= batch_size:
            batch_count += 1
            _process_one_tile_batch(p, pending, batch_count, calc_rectangle)
            # Rebind rather than clear(): process_batch_tiles may keep a
            # reference to the list it was given.
            pending = []
    if pending:
        batch_count += 1
        _process_one_tile_batch(p, pending, batch_count, calc_rectangle)
    return batch_count


def patch_usdu_linear_and_chess_process():
    """Replace USDURedraw.linear_process / chess_process with batched variants.

    When ``p.batch_size <= 1`` the original methods are called unchanged.
    Otherwise tiles are grouped into batches of up to ``p.batch_size`` and
    passed to ``process_batch_tiles``, which updates ``shared.batch``:
      - linear: row-major order;
      - chess: all "white" checkerboard tiles first, then all "black" tiles.
    Both variants finish by setting ``p.width``/``p.height`` to the image size.
    """
    old_linear = usdu.USDURedraw.linear_process
    old_chess = usdu.USDURedraw.chess_process

    @wraps(old_linear)
    def new_linear_process(self, p, image, rows, cols):
        batch_size = getattr(p, 'batch_size', 1)
        logger.info("[USDU Batch Debug] linear_process called batch_size=%s rows=%s cols=%s total_tiles=%s", batch_size, rows, cols, rows * cols)
        if batch_size <= 1:
            logger.info("[USDU Batch Debug] Using original single-tile processing (batch_size=%s)", batch_size)
            return old_linear(self, p, image, rows, cols)

        # Called for its side effects on p (patch_usdu_redraw_init rounds
        # p.width/p.height to the padded tile size). The returned mask/draw
        # pair is not used by the batched path.
        self.init_draw(p, image.width, image.height)

        row_major = ((xi, yi) for yi in range(rows) for xi in range(cols))
        batch_count = _process_tiles_in_batches(p, row_major, batch_size, self.calc_rectangle)
        logger.info("[USDU Batch Debug] Linear processing complete. Processed %s batches total.", batch_count)

        p.width = image.width
        p.height = image.height
        # NOTE(review): the input ``image`` is returned unchanged; redraw
        # results only reach it if process_batch_tiles updates shared.batch[0]
        # in place -- confirm against process_batch_tiles.
        return image

    @wraps(old_chess)
    def new_chess_process(self, p, image, rows, cols):
        batch_size = getattr(p, 'batch_size', 1)
        if batch_size <= 1:
            return old_chess(self, p, image, rows, cols)

        # Called for its side effects on p; see new_linear_process.
        self.init_draw(p, image.width, image.height)

        def tiles_of_color(white: bool):
            # Checkerboard: tile (0, 0) is "white" and the colour alternates
            # along both rows and columns, so no two tiles of one colour
            # share an edge.
            for yi in range(rows):
                for xi in range(cols):
                    if ((xi + yi) % 2 == 0) == white:
                        yield (xi, yi)

        # Process white tiles then black tiles
        for white in (True, False):
            _process_tiles_in_batches(p, tiles_of_color(white), batch_size, self.calc_rectangle)

        p.width = image.width
        p.height = image.height
        return image

    usdu.USDURedraw.linear_process = new_linear_process
    usdu.USDURedraw.chess_process = new_chess_process
# Install the batched linear/chess redraw loops (the last patch applied by this
# module), then report that all patches are in place.
patch_usdu_linear_and_chess_process()
logger.info("USDU batch patches applied successfully.")