"""sd-webui-progressive-growing: always-visible UI + runtime patch for AUTOMATIC1111.

This extension ports the user's provided implementation of ``sample_progressive()``
1:1. It does NOT modify core files on disk; instead it monkey-patches
``modules.processing.StableDiffusionProcessingTxt2Img.sample`` at runtime.
The UI is AlwaysVisible (it does not appear in the Scripts dropdown).
"""

from __future__ import annotations

import gradio as gr

from modules import scripts, sd_samplers, devices
from modules import processing as processing_mod
from modules.processing import create_random_tensors, decode_latent_batch, opt_C, opt_f


# -----------------------------
# Progressive Growing versions
# -----------------------------

def sample_progressive_v1_exact(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
    """Sample txt2img latents at a growing resolution schedule.

    Port of the user-provided ``processing.py::sample_progressive``; on valid
    input the behaviour is kept 1:1.

    Args:
        self: the txt2img processing object. ``self.sampler`` must already be
            created, and the ``progressive_growing_*`` attributes must be set
            (``ProgressiveGrowingAlwaysVisible.process`` does this).
        conditioning, unconditional_conditioning: passed through to the sampler.
        seeds, subseeds, subseed_strength: passed to ``create_random_tensors``.
        prompts: unused; kept so the signature matches ``sample()``.

    Returns:
        The latent batch produced by the last stage.
    """
    import numpy as np
    import torch

    # 1) No forced min_scale >= 0.5 for SDXL any more: the user's values are used as-is.
    min_scale = float(self.progressive_growing_min_scale)
    max_scale = float(self.progressive_growing_max_scale)
    # min_scale > max_scale is intentionally NOT swapped, so a "shrinking"
    # schedule is possible. Swap the two here to enforce strict growth.

    # At least one stage: linspace(..., 0) is empty and resolution_steps[0]
    # would raise IndexError (reachable e.g. via API-supplied arguments).
    num_stages = max(1, int(self.progressive_growing_steps))
    resolution_steps = np.linspace(min_scale, max_scale, num_stages)

    def _snap(v):
        """Truncate a pixel size to a multiple of opt_f, never below opt_f."""
        v_int = int(v)
        v_int = max(opt_f, v_int)
        v_int = (v_int // opt_f) * opt_f
        return max(opt_f, v_int)

    # 2) Starting resolution.
    initial_width = _snap(self.width * resolution_steps[0])
    initial_height = _snap(self.height * resolution_steps[0])

    # 3) Initial latent noise at the starting resolution.
    x = create_random_tensors(
        (opt_C, initial_height // opt_f, initial_width // opt_f),
        seeds,
        subseeds=subseeds,
        subseed_strength=subseed_strength,
        seed_resize_from_h=self.seed_resize_from_h,
        seed_resize_from_w=self.seed_resize_from_w,
        p=self,
    )

    # 4) First full sampler.sample() pass.
    samples = self.sampler.sample(
        self, x, conditioning, unconditional_conditioning,
        image_conditioning=self.txt2img_image_conditioning(x),
    )

    total_stages = len(resolution_steps)

    # 5) Progressive growth: each later stage upscales the latent.
    for i in range(1, total_stages):
        target_width = _snap(self.width * resolution_steps[i])
        target_height = _snap(self.height * resolution_steps[i])

        # Upscale the latent directly (no decode/encode round-trip).
        samples = torch.nn.functional.interpolate(
            samples,
            size=(target_height // opt_f, target_width // opt_f),
            mode='bicubic',
            align_corners=False,
        )

        # 6) Optional short img2img refinement at each stage.
        if self.progressive_growing_refinement:
            # The step budget is split evenly across stages (at least 1 step).
            steps_for_refinement = max(1, self.steps // total_stages)

            # Fresh noise at the new latent size; samples.shape[0] is the batch dim.
            noise = create_random_tensors(
                samples.shape[1:],
                seeds,
                subseeds=subseeds,
                subseed_strength=subseed_strength,
                seed_resize_from_h=self.seed_resize_from_h,
                seed_resize_from_w=self.seed_resize_from_w,
                p=self,
            )

            # Decode to pixels so img2img_image_conditioning can build its
            # conditioning. The clamp to [0, 1] and the remap to [-1, 1] are
            # equivalent to clamping the decoded values to [-1, 1].
            decoded = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
            decoded = torch.stack(decoded).float()
            decoded = torch.clamp((decoded + 1.0) / 2.0, 0.0, 1.0)
            source_img = decoded * 2.0 - 1.0

            self.image_conditioning = self.img2img_image_conditioning(source_img, samples)

            samples = self.sampler.sample_img2img(
                self,
                samples,
                noise,
                conditioning,
                unconditional_conditioning,
                steps=steps_for_refinement,
                image_conditioning=self.image_conditioning,
            )

    return samples


# Maps the UI "Version" dropdown label to its implementation.
_VERSIONS = {
    "v1 (exact)": sample_progressive_v1_exact,
}


# -----------------------------
# Runtime patching
# -----------------------------

_PATCHED = False
_ORIG_SAMPLE = None


def _apply_patch_once() -> None:
    """Patch StableDiffusionProcessingTxt2Img.sample to route into sample_progressive_* when enabled.

    The unpatched ``sample`` is recorded on the class. When the script is
    reloaded and this module is re-imported, the new module therefore wraps
    the true original again, instead of leaving the stale wrapper from the
    previous module copy in place. No-op after the first successful call in
    this module, and a silent no-op if the class cannot be found.
    """
    global _PATCHED, _ORIG_SAMPLE
    if _PATCHED:
        return

    cls = getattr(processing_mod, 'StableDiffusionProcessingTxt2Img', None)
    if cls is None:
        return

    if getattr(cls, '_progressive_growing_ext_patched', False):
        # Already patched by an earlier load of this module (or another copy).
        orig = getattr(cls, '_progressive_growing_orig_sample', None)
        if orig is None:
            # That patch did not record the original sample(); wrapping
            # cls.sample again would double-wrap, so keep the existing wrapper.
            _PATCHED = True
            return
    else:
        orig = cls.sample

    _ORIG_SAMPLE = orig

    def _sample_wrapper(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Dispatch to a progressive-growing version when enabled, else to the original sample()."""
        # Only intercept when the user enabled the feature.
        if getattr(self, 'enable_progressive_growing', False):
            # Mirrors the original sample(), which creates the sampler first.
            self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

            # Pick the version; unknown labels fall back to the exact v1.
            ver = getattr(self, 'progressive_growing_version', 'v1 (exact)')
            fn = _VERSIONS.get(ver, sample_progressive_v1_exact)
            return fn(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts)

        # Fall back to the original behaviour (including its own sampler creation).
        return _ORIG_SAMPLE(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts)

    cls._progressive_growing_orig_sample = orig
    cls.sample = _sample_wrapper
    cls._progressive_growing_ext_patched = True
    _PATCHED = True


# -----------------------------
# Always-visible UI script
# -----------------------------

class ProgressiveGrowingAlwaysVisible(scripts.Script):
    """Always-visible txt2img accordion that configures progressive growing."""

    def title(self):
        """Return the script's display name."""
        return "Progressive Growing"

    def show(self, is_img2img):
        """Show the script only on the txt2img tab, as always visible."""
        return scripts.AlwaysVisible if not is_img2img else False

    def ui(self, is_img2img):
        """Build the accordion and return its components in process() argument order."""
        with gr.Accordion("Progressive Growing", open=False):
            enabled = gr.Checkbox(value=False, label="Enable")
            version = gr.Dropdown(choices=list(_VERSIONS.keys()), value="v1 (exact)", label="Version")
            min_scale = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.25, label="Min scale")
            max_scale = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=1.0, label="Max scale")
            steps = gr.Slider(minimum=2, maximum=16, step=1, value=4, label="Stages")
            refinement = gr.Checkbox(value=True, label="Refinement between stages")
            gr.Markdown(
                "- Starts at Min scale, then increases latent resolution up to Max scale.\n"
                "- Optional short img2img refinement at each stage.\n"
                "- This implementation matches your provided code (v1 exact)."
            )
        return [enabled, version, min_scale, max_scale, steps, refinement]

    def process(self, p, enabled, version, min_scale, max_scale, steps, refinement):
        """Install the runtime patch and copy the UI values onto ``p``.

        Attribute names match those read by the reference implementation. When
        enabled, the settings are also added to ``p.extra_generation_params``
        so they can appear in the infotext.
        """
        _apply_patch_once()

        p.enable_progressive_growing = bool(enabled)
        p.progressive_growing_version = str(version)
        p.progressive_growing_min_scale = float(min_scale)
        p.progressive_growing_max_scale = float(max_scale)
        # At least one stage; API callers can bypass the UI slider's minimum.
        p.progressive_growing_steps = max(1, int(steps))
        p.progressive_growing_refinement = bool(refinement)

        if not p.enable_progressive_growing:
            return

        try:
            params = p.extra_generation_params
            params["Progressive Growing"] = "True"
            params["Min Scale"] = p.progressive_growing_min_scale
            params["Max Scale"] = p.progressive_growing_max_scale
            params["Progressive Growing Steps"] = p.progressive_growing_steps
            params["Refinement"] = "True" if p.progressive_growing_refinement else None
            params["PG Version"] = p.progressive_growing_version
        except (AttributeError, TypeError):
            # Best effort: extra_generation_params may be missing (AttributeError)
            # or None (TypeError) in some contexts; the infotext is then skipped.
            pass