# sd-webui-progressive-growing / scripts / progressive_growing_always.py
# (uploaded by dikdimon via SD-Hub, commit 66fad01)
"""sd-webui-progressive-growing
Always-visible UI + runtime patch for AUTOMATIC1111.
This extension ports the user's provided implementation of `sample_progressive()` 1:1.
It does NOT modify core files on disk; instead it monkey-patches
`modules.processing.StableDiffusionProcessingTxt2Img.sample` at runtime.
UI is AlwaysVisible (not in the Scripts dropdown).
"""
from __future__ import annotations
import gradio as gr
from modules import scripts, sd_samplers, devices
from modules import processing as processing_mod
from modules.processing import create_random_tensors, decode_latent_batch, opt_C, opt_f
# -----------------------------
# Progressive Growing versions
# -----------------------------
def sample_progressive_v1_exact(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
    """Progressive-growing txt2img sampling (1:1 port of the provided
    ``processing.py::sample_progressive`` implementation).

    Sampling starts at a reduced latent resolution (``min_scale`` of the
    requested width/height), then the latent is bicubic-upscaled through
    ``progressive_growing_steps`` stages up to ``max_scale``, optionally
    running a short img2img refinement pass after each upscale.

    Parameters match ``StableDiffusionProcessingTxt2Img.sample``; the
    progressive-growing options (``progressive_growing_min_scale`` /
    ``_max_scale`` / ``_steps`` / ``_refinement``) are read from
    attributes stored on ``self`` by the UI script.

    Returns:
        The final latent sample batch, same contract as the stock
        ``sample`` method.
    """
    import numpy as np
    import torch

    # NOTE: there is deliberately no forced min_scale >= 0.5 clamp for SDXL;
    # the user-supplied scales are used verbatim.
    min_scale = float(self.progressive_growing_min_scale)
    max_scale = float(self.progressive_growing_max_scale)
    # If "shrink" should be disallowed, restore this swap so min/max always
    # describe an honest "grow" even when the user reversed them:
    # if min_scale > max_scale:
    #     min_scale, max_scale = max_scale, min_scale
    resolution_steps = np.linspace(min_scale, max_scale, int(self.progressive_growing_steps))

    def _snap(v):
        # Snap a pixel dimension down to a multiple of the latent stride
        # (opt_f), never below opt_f itself.
        v_int = int(v)
        v_int = max(opt_f, v_int)
        v_int = (v_int // opt_f) * opt_f
        return max(opt_f, v_int)

    # Starting (smallest) resolution.
    initial_width = _snap(self.width * resolution_steps[0])
    initial_height = _snap(self.height * resolution_steps[0])

    # Initial latent noise at the reduced resolution.
    x = create_random_tensors(
        (opt_C, initial_height // opt_f, initial_width // opt_f),
        seeds,
        subseeds=subseeds,
        subseed_strength=subseed_strength,
        seed_resize_from_h=self.seed_resize_from_h,
        seed_resize_from_w=self.seed_resize_from_w,
        p=self
    )

    # First full sampler pass at the reduced resolution.
    samples = self.sampler.sample(
        self,
        x,
        conditioning,
        unconditional_conditioning,
        image_conditioning=self.txt2img_image_conditioning(x)
    )

    total_stages = len(resolution_steps)

    # Progressive growth: one latent upscale (+ optional refinement) per stage.
    for i in range(1, total_stages):
        target_width = _snap(self.width * resolution_steps[i])
        target_height = _snap(self.height * resolution_steps[i])

        # Upscale the latent to this stage's resolution.
        samples = torch.nn.functional.interpolate(
            samples,
            size=(target_height // opt_f, target_width // opt_f),
            mode='bicubic',
            align_corners=False
        )

        # Optional short img2img refinement at the new resolution.
        if self.progressive_growing_refinement:
            # Spread the configured step budget roughly evenly over stages.
            steps_for_refinement = max(1, self.steps // total_stages)
            noise = create_random_tensors(
                samples.shape[1:],
                seeds,
                subseeds=subseeds,
                subseed_strength=subseed_strength,
                seed_resize_from_h=self.seed_resize_from_h,
                seed_resize_from_w=self.seed_resize_from_w,
                p=self
            )
            # Decode to image space to build img2img conditioning
            # (round-trip through [0, 1], then back to [-1, 1]).
            decoded = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
            decoded = torch.stack(decoded).float()
            decoded = torch.clamp((decoded + 1.0) / 2.0, 0.0, 1.0)
            source_img = decoded * 2.0 - 1.0
            self.image_conditioning = self.img2img_image_conditioning(source_img, samples)
            samples = self.sampler.sample_img2img(
                self,
                samples,
                noise,
                conditioning,
                unconditional_conditioning,
                steps=steps_for_refinement,
                image_conditioning=self.image_conditioning
            )
    return samples
# Registry of selectable progressive-growing implementations, keyed by the
# label shown in the UI "Version" dropdown.
_VERSIONS = {
    "v1 (exact)": sample_progressive_v1_exact,
}
# -----------------------------
# Runtime patching
# -----------------------------
# Module-level patch state: ensure `sample` is wrapped at most once and keep a
# handle to the original implementation for the non-enabled fallback path.
_PATCHED = False
_ORIG_SAMPLE = None
def _apply_patch_once() -> None:
    """Patch StableDiffusionProcessingTxt2Img.sample to route into sample_progressive_* when enabled."""
    global _PATCHED, _ORIG_SAMPLE
    # Fast path: this module instance has already installed the wrapper.
    if _PATCHED:
        return
    cls = getattr(processing_mod, 'StableDiffusionProcessingTxt2Img', None)
    if cls is None:
        # Core class not present (unexpected webui version); skip silently.
        return
    # Already patched by us (or by another loaded copy of this extension):
    # the class-level flag survives module reloads, unlike our _PATCHED global.
    if getattr(cls, '_progressive_growing_ext_patched', False):
        _PATCHED = True
        return
    _ORIG_SAMPLE = cls.sample
    def _sample_wrapper(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        # Only intercept when the user enabled the feature for this job.
        if getattr(self, 'enable_progressive_growing', False):
            # Mirror the reference code: the sampler is created at the start of sample().
            self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
            # Pick the requested version (defaults to the exact v1 port).
            ver = getattr(self, 'progressive_growing_version', 'v1 (exact)')
            fn = _VERSIONS.get(ver, sample_progressive_v1_exact)
            return fn(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts)
        # Fall back to the original behaviour (including its own sampler creation).
        return _ORIG_SAMPLE(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts)
    cls.sample = _sample_wrapper
    cls._progressive_growing_ext_patched = True
    _PATCHED = True
# -----------------------------
# Always-visible UI script
# -----------------------------
class ProgressiveGrowingAlwaysVisible(scripts.Script):
    """Always-visible txt2img accordion that configures and arms the
    progressive-growing runtime patch for each generation."""

    def title(self):
        return "Progressive Growing"

    def show(self, is_img2img):
        # txt2img only; rendered as an always-visible accordion instead of
        # an entry in the Scripts dropdown.
        if is_img2img:
            return False
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        with gr.Accordion("Progressive Growing", open=False):
            enable_box = gr.Checkbox(value=False, label="Enable")
            version_dd = gr.Dropdown(choices=list(_VERSIONS.keys()), value="v1 (exact)", label="Version")
            lo_scale = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.25, label="Min scale")
            hi_scale = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=1.0, label="Max scale")
            stage_count = gr.Slider(minimum=2, maximum=16, step=1, value=4, label="Stages")
            refine_box = gr.Checkbox(value=True, label="Refinement between stages")
            gr.Markdown(
                "- Starts at Min scale, then increases latent resolution up to Max scale.\n"
                "- Optional short img2img refinement at each stage.\n"
                "- This implementation matches your provided code (v1 exact)."
            )
        return [enable_box, version_dd, lo_scale, hi_scale, stage_count, refine_box]

    def process(self, p, enabled, version, min_scale, max_scale, steps, refinement):
        _apply_patch_once()
        # Store parameters on p, matching the attribute names the patched
        # sample() implementation reads.
        p.enable_progressive_growing = bool(enabled)
        p.progressive_growing_version = str(version)
        p.progressive_growing_min_scale = float(min_scale)
        p.progressive_growing_max_scale = float(max_scale)
        p.progressive_growing_steps = int(steps)
        p.progressive_growing_refinement = bool(refinement)
        if not p.enable_progressive_growing:
            return
        # Record settings so they appear in the generation infotext.
        try:
            params = p.extra_generation_params
            params["Progressive Growing"] = "True"
            params["Min Scale"] = p.progressive_growing_min_scale
            params["Max Scale"] = p.progressive_growing_max_scale
            params["Progressive Growing Steps"] = p.progressive_growing_steps
            params["Refinement"] = "True" if p.progressive_growing_refinement else None
            params["PG Version"] = p.progressive_growing_version
        except Exception:
            # extra_generation_params may not exist in some contexts;
            # infotext is best-effort only.
            pass