Spaces:
Running on Zero
| """Ultimate SD Upscale - tiled upscaling with seam fix.""" | |
| from src.AutoEncoders import VariationalAE | |
| from src.sample import sampling | |
| from src.UltimateSDUpscale import USDU_upscaler, image_util | |
| import torch | |
| from PIL import ImageFilter, ImageDraw, Image | |
| from enum import Enum | |
| import math | |
# Shared progress/interrupt object from USDU_upscaler: the tile loops read
# ``state.interrupted`` and USDUpscaler.calc_jobs_count writes ``state.job_count``.
state = USDU_upscaler.state


class UnsupportedModel(Exception):
    """Error type for models this upscaler does not support (not raised in this module chunk)."""
class StableDiffusionProcessing:
    """Container for SD processing parameters; the tile passes mutate it in place."""

    def __init__(self, init_img: Image.Image, model, positive, negative, vae, seed, steps, cfg,
                 sampler_name, scheduler, denoise, upscale_by, uniform_tile_mode, callback=None):
        size = (init_img.width, init_img.height)

        # Image / mask state; overwritten per tile by the redraw and seam-fix passes.
        self.init_images = [init_img]
        self.image_mask = None
        self.mask_blur = 0
        self.inpaint_full_res_padding = 0
        self.width, self.height = size
        self.init_size = size

        # Model, VAE and conditioning.
        self.model = model
        self.positive = positive
        self.negative = negative
        self.vae = vae

        # Sampler settings.
        self.seed = seed
        self.steps = steps
        self.cfg = cfg
        self.sampler_name = sampler_name
        self.scheduler = scheduler
        self.denoise = denoise

        # Upscale settings and bookkeeping.
        self.upscale_by = upscale_by
        self.uniform_tile_mode = uniform_tile_mode
        self.extra_generation_params = {}
        self.callback = callback
class Processed:
    """Result holder returned by process_images() and Script.run()."""

    def __init__(self, p, images, seed, info):
        # ``p`` is accepted for interface compatibility but is not stored.
        self.images = images
        self.seed = seed
        self.info = info

    def infotext(self, p, index):
        """Return generation info text; this port never produces any."""
        return None
def fix_seed(p):
    """Leave ``p`` untouched; kept so Script.run can call it like the A1111 hook."""
    return None
def process_images(p, pipeline=False):
    """Re-sample the masked area of every image in ``USDU_upscaler.batch``.

    Steps:
      1. Take the white region of ``p.image_mask``, grown by ``p.inpaint_full_res_padding``.
      2. Expand it to the aspect ratio of ``p.width`` x ``p.height``.
      3. Crop that region from each batch image and resize it to that tile size.
      4. VAE-encode, sample, decode, and resize back.
      5. Alpha-composite the result through the (optionally blurred) mask.

    ``USDU_upscaler.batch`` is updated in place.

    Args:
        p: processing parameters (model, conditioning, sampler settings,
            ``image_mask``, ``init_images``, and the tile size in ``width``/``height``).
        pipeline: forwarded to ``sampling.common_ksampler``.

    Returns:
        Processed wrapping the first image of the updated batch.
    """
    image_mask = p.image_mask.convert("L")
    init_image = p.init_images[0]

    # Bounding box of the mask, then grown to the tile's aspect ratio.
    crop_region = image_util.get_crop_region(image_mask, p.inpaint_full_res_padding)
    x1, y1, x2, y2 = crop_region
    crop_ratio = (x2 - x1) / (y2 - y1)
    p_ratio = p.width / p.height
    if crop_ratio > p_ratio:
        target_width, target_height = x2 - x1, round((x2 - x1) / p_ratio)
    else:
        target_width, target_height = round((y2 - y1) * p_ratio), y2 - y1
    crop_region, _ = image_util.expand_crop(
        crop_region, image_mask.width, image_mask.height, target_width, target_height
    )
    # The tile is cropped from (and pasted back at) the *expanded* region, so its
    # origin -- not the pre-expansion (x1, y1) -- is the tile's real offset.
    crop_left, crop_top = crop_region[0], crop_region[1]
    tile_size = (p.width, p.height)

    if p.mask_blur > 0:
        image_mask = image_mask.filter(ImageFilter.GaussianBlur(p.mask_blur))

    # Same region from every batch image, brought to the sampling tile size.
    tiles = [img.crop(crop_region) for img in USDU_upscaler.batch]
    initial_tile_size = tiles[0].size
    tiles = [t.resize(tile_size, Image.Resampling.LANCZOS) if t.size != tile_size else t for t in tiles]

    positive_cropped = image_util.crop_cond(p.positive, crop_region, p.init_size, init_image.size, tile_size)
    negative_cropped = image_util.crop_cond(p.negative, crop_region, p.init_size, init_image.size, tile_size)

    batched_tiles = torch.cat([image_util.pil_to_tensor(t) for t in tiles], dim=0)
    (latent,) = VariationalAE.VAEEncode().encode(p.vae, batched_tiles)

    # Detect Flux / Flux2 sampling so the sampler receives the matching flags.
    model_sampling_obj = p.model.get_model_object("model_sampling")
    from src.sample.sampling import ModelSamplingFlux, ModelSamplingFlux2
    is_flux = isinstance(model_sampling_obj, (ModelSamplingFlux, ModelSamplingFlux2))
    is_flux2 = isinstance(model_sampling_obj, ModelSamplingFlux2)

    # Pass the tile's pixel offset for positional-embedding coherence (Flux/DiT).
    # Both dicts are copied so the model's own options are not mutated.
    model_options = p.model.model_options.copy()
    transformer_options = model_options.get("transformer_options", {}).copy()
    transformer_options["top"] = crop_top
    transformer_options["left"] = crop_left
    model_options["transformer_options"] = transformer_options

    (samples,) = sampling.common_ksampler(p.model, p.seed, p.steps, p.cfg, p.sampler_name, p.scheduler,
                                          positive_cropped, negative_cropped, latent, denoise=p.denoise,
                                          pipeline=pipeline, flux=is_flux, flux2=is_flux2,
                                          model_options=model_options, callback=p.callback)
    (decoded,) = VariationalAE.VAEDecode().decode(p.vae, samples)

    for i in range(len(decoded)):
        tile_sampled = image_util.tensor_to_pil(decoded, i)
        base_image = USDU_upscaler.batch[i]
        if tile_sampled.size != initial_tile_size:
            tile_sampled = tile_sampled.resize(initial_tile_size, Image.Resampling.LANCZOS)
        # Transparent canvas holding only the sampled tile at its crop position.
        image_tile_only = Image.new("RGBA", base_image.size)
        image_tile_only.paste(tile_sampled, (crop_left, crop_top))
        # Give the tile the mask as alpha, limited to the tile's own footprint
        # (the canvas's alpha is the paste mask).
        temp = image_tile_only.copy()
        temp.putalpha(image_mask.resize(temp.size))
        image_tile_only.paste(temp, image_tile_only)
        result = base_image.convert("RGBA")
        result.alpha_composite(image_tile_only)
        USDU_upscaler.batch[i] = result.convert("RGB")

    return Processed(p, [USDU_upscaler.batch[0]], p.seed, None)
class USDUMode(Enum):
    """Tile redraw mode; NONE disables the redraw pass."""

    LINEAR = 0
    CHESS = 1
    NONE = 2


class USDUSFMode(Enum):
    """Seam-fix mode; NONE disables the seam-fix pass."""

    NONE = 0
    BAND_PASS = 1
    HALF_TILE = 2
    HALF_TILE_PLUS_INTERSECTIONS = 3
class USDUpscaler:
    """Drive one Ultimate SD Upscale run: model upscale, tile redraw, then seam fix.

    Holds the processing parameters ``p``, the working image, the selected
    upscaler entry, and the redraw / seam-fix helpers, which share one tile size.
    """

    def __init__(self, p, image, upscaler_index, save_redraw, save_seams_fix, tile_width, tile_height):
        self.p, self.image = p, image
        # Whole-number factor needed to reach the canvas size from the input image.
        self.scale_factor = math.ceil(max(p.width, p.height) / max(image.width, image.height))
        self.upscaler = USDU_upscaler.sd_upscalers[upscaler_index]
        self.redraw = USDURedraw()
        # A falsy tile dimension falls back to the other one.
        self.redraw.tile_width = tile_width or tile_height
        self.redraw.tile_height = tile_height or tile_width
        self.seams_fix = USDUSeamsFix()
        self.seams_fix.tile_width = self.redraw.tile_width
        self.seams_fix.tile_height = self.redraw.tile_height
        self.initial_info = None
        self.rows = math.ceil(self.p.height / self.redraw.tile_height)
        self.cols = math.ceil(self.p.width / self.redraw.tile_width)

    def get_factor(self, num):
        """Return the largest of 4/3/2 dividing ``num`` (2 when ``num`` is 1), else 0."""
        if num == 1:
            return 2
        for f in [4, 3, 2]:
            if num % f == 0:
                return f
        return 0

    def get_factors(self):
        """Split ``scale_factor`` into a sequence of 2x/3x/4x upscaler passes.

        Stores ``self.scales`` as an ``enumerate`` over the per-pass factors. Their
        product is >= ``scale_factor``; upscale() resizes the result to the canvas
        afterwards, which removes any overshoot.
        """
        scales, current = [], 1
        while current < self.scale_factor:
            remaining = self.scale_factor // current
            # get_factor() returns 0 when `remaining` has no 2/3/4 divisor (e.g. 5, 7).
            # Using 0 would zero `current` and make the next `// current` raise
            # ZeroDivisionError. In that case `remaining + 1` is even, so it always
            # yields a usable factor.
            factor = self.get_factor(remaining) or self.get_factor(remaining + 1)
            scales.append(factor)
            current *= factor
        self.scales = enumerate(scales)

    def upscale(self):
        """Run the upscaler passes on ``self.image``, then resize it to the canvas size."""
        print(f"Canva: {self.p.width}x{self.p.height}, Image: {self.image.width}x{self.image.height}, Scale: {self.scale_factor}")
        self.get_factors()
        for idx, val in self.scales:
            print(f"Upscaling iteration {idx + 1} with scale factor {val}")
            self.image = self.upscaler.scaler.upscale(self.image, val, self.upscaler.data_path)
        self.image = self.image.resize((self.p.width, self.p.height), resample=Image.LANCZOS)

    def setup_redraw(self, mode, padding, mask_blur):
        """Configure the redraw pass; ``mode`` may be a USDUMode member or its value."""
        self.redraw.mode = USDUMode(mode)
        self.redraw.enabled = self.redraw.mode != USDUMode.NONE
        self.redraw.padding = padding
        self.p.mask_blur = mask_blur

    def setup_seams_fix(self, padding, denoise, mask_blur, width, mode):
        """Configure the seam-fix pass; ``mode`` may be a USDUSFMode member or its value."""
        self.seams_fix.padding, self.seams_fix.denoise = padding, denoise
        self.seams_fix.mask_blur, self.seams_fix.width = mask_blur, width
        self.seams_fix.mode = USDUSFMode(mode)
        self.seams_fix.enabled = self.seams_fix.mode != USDUSFMode.NONE

    def calc_jobs_count(self):
        """Store the number of tile passes that process() will run in ``state.job_count``."""
        redraw = (self.rows * self.cols) if self.redraw.enabled else 0
        # Seam fixing visits every interior horizontal ((rows-1) * cols) and vertical
        # (rows * (cols-1)) seam segment -- but only when it is enabled.
        if self.seams_fix.enabled:
            seams = self.rows * (self.cols - 1) + (self.rows - 1) * self.cols
        else:
            seams = 0
        state.job_count = redraw + seams

    def print_info(self):
        print(f"Tile: {self.redraw.tile_width}x{self.redraw.tile_height}, Grid: {self.rows}x{self.cols}")

    def add_extra_info(self):
        """Record the upscaler name and tile size in ``p.extra_generation_params``."""
        self.p.extra_generation_params.update({
            "Ultimate SD upscale upscaler": self.upscaler.name,
            "Ultimate SD upscale tile_width": self.redraw.tile_width,
            "Ultimate SD upscale tile_height": self.redraw.tile_height,
        })

    def process(self, pipeline):
        """Run the enabled redraw and seam-fix passes, collecting each result image."""
        USDU_upscaler.state.begin()
        try:
            self.calc_jobs_count()
            self.result_images = []
            if self.redraw.enabled:
                self.image = self.redraw.start(self.p, self.image, self.rows, self.cols, pipeline)
                self.initial_info = self.redraw.initial_info
            self.result_images.append(self.image)
            if self.seams_fix.enabled:
                self.image = self.seams_fix.start(self.p, self.image, self.rows, self.cols, pipeline)
                self.initial_info = self.seams_fix.initial_info
                self.result_images.append(self.image)
        finally:
            # Always close the job, even if a pass raises.
            USDU_upscaler.state.end()
class USDURedraw:
    """Tile redraw: re-sample the upscaled image one rectangular tile at a time."""

    def init_draw(self, p, width, height):
        """Configure ``p`` for padded tile passes and return a black mask and its drawer."""
        p.inpaint_full_res = True
        p.inpaint_full_res_padding = self.padding
        # Tile plus padding, rounded up to a multiple of 64.
        p.width = math.ceil((self.tile_width + self.padding) / 64) * 64
        p.height = math.ceil((self.tile_height + self.padding) / 64) * 64
        mask = Image.new("L", (width, height), "black")
        return mask, ImageDraw.Draw(mask)

    def calc_rectangle(self, xi, yi):
        """Return the (x1, y1, x2, y2) pixel box of the tile at column ``xi``, row ``yi``."""
        return xi * self.tile_width, yi * self.tile_height, (xi + 1) * self.tile_width, (yi + 1) * self.tile_height

    def linear_process(self, p, image, rows, cols, pipeline=False):
        """Redraw tiles row by row, left to right; return the final image.

        Each tile is whitened in the shared mask, processed, then blacked out again.
        Stops early when ``state.interrupted`` is set.
        """
        mask, draw = self.init_draw(p, image.width, image.height)
        processed = None
        for yi in range(rows):
            if state.interrupted:
                break
            for xi in range(cols):
                if state.interrupted:
                    break
                rect = self.calc_rectangle(xi, yi)
                draw.rectangle(rect, fill="white")
                p.init_images, p.image_mask = [image], mask
                processed = process_images(p, pipeline)
                draw.rectangle(rect, fill="black")
                if processed.images:
                    image = processed.images[0]
        p.width, p.height = image.width, image.height
        # No tile may have run (empty grid or immediate interrupt); in that case
        # `processed` is still None and there is no info text to take.
        if processed is not None:
            self.initial_info = processed.infotext(p, 0)
        return image

    def start(self, p, image, rows, cols, pipeline=False):
        """Reset ``initial_info`` and run the redraw; tiles are always visited in linear order."""
        self.initial_info = None
        return self.linear_process(p, image, rows, cols, pipeline)
class USDUSeamsFix:
    """Seam fixing: re-sample gradient-masked bands centred on the tile borders."""

    def init_draw(self, p):
        """Reset ``initial_info`` and size ``p`` to the padded tile, rounded up to 64."""
        self.initial_info = None
        p.width = math.ceil((self.tile_width + self.padding) / 64) * 64
        p.height = math.ceil((self.tile_height + self.padding) / 64) * 64

    def _configure_pass(self, p):
        """Apply the seam-fix denoise strength and mask blur to ``p``."""
        # process_images() samples with ``p.denoise``. Setting only the A1111-style
        # ``denoising_strength`` left every seam pass at the main redraw denoise.
        p.denoise = self.denoise
        p.denoising_strength = self.denoise
        p.mask_blur = self.mask_blur

    def _make_gradients(self):
        """Build tile-sized (row, column) masks.

        The row mask ramps black -> white -> black from top to bottom. The column
        mask ramps the same way from left to right.
        """
        gradient = Image.linear_gradient("L")
        half_w, half_h = self.tile_width // 2, self.tile_height // 2
        row_gradient = Image.new("L", (self.tile_width, self.tile_height), "black")
        row_gradient.paste(gradient.resize((self.tile_width, half_h), Image.BICUBIC), (0, 0))
        row_gradient.paste(gradient.rotate(180).resize((self.tile_width, half_h), Image.BICUBIC), (0, half_h))
        col_gradient = Image.new("L", (self.tile_width, self.tile_height), "black")
        col_gradient.paste(gradient.rotate(90).resize((half_w, self.tile_height), Image.BICUBIC), (0, 0))
        col_gradient.paste(gradient.rotate(270).resize((half_w, self.tile_height), Image.BICUBIC), (half_w, 0))
        return row_gradient, col_gradient

    def _fix_seam(self, p, image, gradient, origin, pipeline):
        """Run one masked pass with ``gradient`` pasted at ``origin``.

        Returns the (possibly updated) image and the Processed result.
        """
        # Each seam pass uses the unpadded tile size, overriding init_draw().
        p.width, p.height = self.tile_width, self.tile_height
        p.inpaint_full_res, p.inpaint_full_res_padding = True, self.padding
        mask = Image.new("L", (image.width, image.height), "black")
        mask.paste(gradient, origin)
        p.init_images, p.image_mask = [image], mask
        processed = process_images(p, pipeline)
        if processed.images:
            image = processed.images[0]
        return image, processed

    def half_tile_process(self, p, image, rows, cols, pipeline=False):
        """Fix horizontal seams, then vertical seams; return the final image.

        Stops early when ``state.interrupted`` is set.
        """
        self.init_draw(p)
        processed = None
        row_gradient, col_gradient = self._make_gradients()
        self._configure_pass(p)
        tw, th = self.tile_width, self.tile_height
        # Horizontal seams: band centred on the border between tile rows yi and yi + 1.
        for yi in range(rows - 1):
            for xi in range(cols):
                if state.interrupted:
                    break
                image, processed = self._fix_seam(p, image, row_gradient, (xi * tw, yi * th + th // 2), pipeline)
        # Vertical seams: band centred on the border between tile columns xi and xi + 1.
        for yi in range(rows):
            for xi in range(cols - 1):
                if state.interrupted:
                    break
                image, processed = self._fix_seam(p, image, col_gradient, (xi * tw + tw // 2, yi * th), pipeline)
        p.width, p.height = image.width, image.height
        if processed is not None:
            self.initial_info = processed.infotext(p, 0)
        return image

    def start(self, p, image, rows, cols, pipeline=False):
        """Run the seam fix; every enabled seam-fix mode uses the half-tile strategy."""
        return self.half_tile_process(p, image, rows, cols, pipeline)
class Script(USDU_upscaler.Script):
    """Script runner: flatten the input, upscale, redraw tiles, fix seams."""

    def run(self, p, _, tile_width, tile_height, mask_blur, padding, seams_fix_width, seams_fix_denoise,
            seams_fix_padding, upscaler_index, save_upscaled_image, redraw_mode, save_seams_fix_image,
            seams_fix_mask_blur, seams_fix_type, target_size_type, custom_width, custom_height, custom_scale,
            pipeline=False):
        """Execute the full upscale on ``p`` and return a Processed with the result images.

        ``target_size_type``, ``custom_width`` and ``custom_height`` are accepted but unused.
        """
        fix_seed(p)
        USDU_upscaler.torch_gc()

        # Single image, nothing saved; the tile passes turn full-res inpainting on later.
        p.do_not_save_grid = True
        p.do_not_save_samples = True
        p.inpaint_full_res = False
        p.inpainting_fill = 1
        p.n_iter = 1
        p.batch_size = 1

        source = image_util.flatten(p.init_images[0], USDU_upscaler.opts.img2img_background_color)
        # Provisional canvas size, rounded up to 64; the patched USDUpscaler.__init__
        # recomputes it from p.upscale_by.
        p.width = math.ceil((source.width * custom_scale) / 64) * 64
        p.height = math.ceil((source.height * custom_scale) / 64) * 64

        job = USDUpscaler(p, source, upscaler_index, save_upscaled_image, save_seams_fix_image,
                          tile_width, tile_height)
        job.upscale()
        job.setup_redraw(redraw_mode, padding, mask_blur)
        job.setup_seams_fix(seams_fix_padding, seams_fix_denoise, seams_fix_mask_blur,
                            seams_fix_width, seams_fix_type)
        job.print_info()
        job.add_extra_info()
        job.process(pipeline)

        info = job.initial_info or ""
        return Processed(p, job.result_images, p.seed, info)
# Monkey-patch overrides: the base classes round canvas and tile sizes to 64.
# These wrappers re-round them to the model's latent downscale factor, and keep
# the rest of the image batch at canvas size once the first image is upscaled.


def _latent_downscale_factor(p, default=8):
    """Return the latent downscale factor of ``p.model`` (e.g. 8 for SD, 16 for Flux).

    Best-effort: ``default`` is returned if the lookup fails for any reason, or if
    the latent format has no ``downscale_factor`` attribute.
    """
    try:
        latent_format = p.model.get_model_object("latent_format")
        return getattr(latent_format, "downscale_factor", default)
    except Exception:  # deliberate: unknown model wrappers fall back to ``default``
        return default


def _round_up(value, multiple):
    """Round ``value`` up to the next multiple of ``multiple``."""
    return math.ceil(value / multiple) * multiple


_old_init = USDUpscaler.__init__


def _new_init(self, p, image, upscaler_index, save_redraw, save_seams_fix, tile_width, tile_height):
    """Size the canvas to ``image`` x ``p.upscale_by`` on the latent grid, then run the base __init__."""
    factor = _latent_downscale_factor(p)
    p.width = _round_up(image.width * p.upscale_by, factor)
    p.height = _round_up(image.height * p.upscale_by, factor)
    _old_init(self, p, image, upscaler_index, save_redraw, save_seams_fix, tile_width, tile_height)


USDUpscaler.__init__ = _new_init

_old_redraw = USDURedraw.init_draw


def _new_redraw(self, p, width, height):
    """Run the base redraw init_draw, then re-round the padded tile size to the latent grid."""
    mask, draw = _old_redraw(self, p, width, height)
    factor = _latent_downscale_factor(p)
    p.width = _round_up(self.tile_width + self.padding, factor)
    p.height = _round_up(self.tile_height + self.padding, factor)
    return mask, draw


USDURedraw.init_draw = _new_redraw

_old_seams = USDUSeamsFix.init_draw


def _new_seams(self, p):
    """Run the base seam-fix init_draw, then re-round the padded tile size to the latent grid."""
    _old_seams(self, p)
    factor = _latent_downscale_factor(p)
    p.width = _round_up(self.tile_width + self.padding, factor)
    p.height = _round_up(self.tile_height + self.padding, factor)


USDUSeamsFix.init_draw = _new_seams

_old_upscale = USDUpscaler.upscale


def _new_upscale(self):
    """Run the base upscale, then resize the remaining batch images to the canvas size.

    Only the first image goes through the configured upscaler; the others are
    LANCZOS-resized.
    """
    _old_upscale(self)
    USDU_upscaler.batch = [self.image] + [
        img.resize((self.p.width, self.p.height), Image.LANCZOS) for img in USDU_upscaler.batch[1:]
    ]


USDUpscaler.upscale = _new_upscale
# NOTE(review): presumably the maximum resolution exposed to callers/UI; it is
# not referenced in this module chunk.
MAX_RESOLUTION = 8192

# ``mode_type`` string accepted by UltimateSDUpscale.upscale -> redraw mode.
MODES = {
    "Linear": USDUMode.LINEAR,
    "Chess": USDUMode.CHESS,
    "None": USDUMode.NONE,
}

# ``seam_fix_mode`` string accepted by UltimateSDUpscale.upscale -> seam-fix mode.
SEAM_FIX_MODES = {
    "None": USDUSFMode.NONE,
    "Band Pass": USDUSFMode.BAND_PASS,
    "Half Tile": USDUSFMode.HALF_TILE,
    "Half Tile + Intersections": USDUSFMode.HALF_TILE_PLUS_INTERSECTIONS,
}
class UltimateSDUpscale:
    """Main entry point for Ultimate SD Upscale."""

    def upscale(self, image, model, positive, negative, vae, upscale_by, seed, steps, cfg, sampler_name,
                scheduler, denoise, upscale_model, mode_type, tile_width, tile_height, mask_blur, tile_padding,
                seam_fix_mode, seam_fix_denoise, seam_fix_mask_blur, seam_fix_width, seam_fix_padding,
                force_uniform_tiles, pipeline=False, callback=None):
        """Upscale every frame of the ``image`` batch and return a 1-tuple holding the result tensor.

        ``mode_type`` must be a key of ``MODES`` and ``seam_fix_mode`` a key of
        ``SEAM_FIX_MODES`` (a KeyError is raised otherwise). Frames are processed
        through the module-level ``USDU_upscaler.batch`` list.
        """
        shared = USDU_upscaler
        # Slot 0 holds the upscaler that run() selects via upscaler_index=0.
        shared.sd_upscalers[0] = shared.UpscalerData()
        shared.actual_upscaler = upscale_model
        shared.batch = [image_util.tensor_to_pil(image, idx) for idx in range(len(image))]

        params = StableDiffusionProcessing(
            image_util.tensor_to_pil(image), model, positive, negative, vae, seed, steps, cfg,
            sampler_name, scheduler, denoise, upscale_by, force_uniform_tiles, callback=callback,
        )
        runner = Script()
        runner.run(
            p=params,
            _=None,
            tile_width=tile_width,
            tile_height=tile_height,
            mask_blur=mask_blur,
            padding=tile_padding,
            seams_fix_width=seam_fix_width,
            seams_fix_denoise=seam_fix_denoise,
            seams_fix_padding=seam_fix_padding,
            upscaler_index=0,
            save_upscaled_image=False,
            redraw_mode=MODES[mode_type],
            save_seams_fix_image=False,
            seams_fix_mask_blur=seam_fix_mask_blur,
            seams_fix_type=SEAM_FIX_MODES[seam_fix_mode],
            target_size_type=2,
            custom_width=None,
            custom_height=None,
            custom_scale=upscale_by,
            pipeline=pipeline,
        )
        frames = [image_util.pil_to_tensor(frame) for frame in shared.batch]
        return (torch.cat(frames, dim=0),)